//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation --------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelLowering.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
#include "NVPTXUtilities.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCSectionELF.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream>

#undef DEBUG_TYPE
#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;

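// Incremented once per lowered call site; used both as the ID for the
// CALLSEQ_START/CALLSEQ_END markers and to give each emitted call prototype
// a unique "prototype_<n>" name (see getPrototype below).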
static unsigned int uniqueCallSite = 0;

static cl::opt<bool> sched4reg(
    "nvptx-sched4reg",
    cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));

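// The vector types that PTX can load and store directly (with .v2/.v4 vector
// accesses); the constructor below registers custom LOAD/STORE handling for
// exactly these types.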
static bool IsPTXVectorType(MVT VT) {
  switch (VT.SimpleTy) {
  default:
    return false;
  case MVT::v2i1:
  case MVT::v4i1:
  case MVT::v2i8:
  case MVT::v4i8:
  case MVT::v2i16:
  case MVT::v4i16:
  case MVT::v2i32:
  case MVT::v4i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v4f32:
  case MVT::v2f64:
    return true;
  }
}

/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
/// LowerCall, and LowerReturn.
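/// For example, <4 x float> decomposes to four f32 values at byte offsets
/// {0, 4, 8, 12}, whereas ComputeValueVTs would return the single EVT v4f32.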
static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
                               SmallVectorImpl<EVT> &ValueVTs,
                               SmallVectorImpl<uint64_t> *Offsets = nullptr,
                               uint64_t StartingOffset = 0) {
  SmallVector<EVT, 16> TempVTs;
  SmallVector<uint64_t, 16> TempOffsets;

  ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
  for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
    EVT VT = TempVTs[i];
    uint64_t Off = TempOffsets[i];
    if (VT.isVector())
      for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
        ValueVTs.push_back(VT.getVectorElementType());
        if (Offsets)
          Offsets->push_back(Off + j * VT.getVectorElementType().getStoreSize());
      }
    else {
      ValueVTs.push_back(VT);
      if (Offsets)
        Offsets->push_back(Off);
    }
  }
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
    : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
      nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {

  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
  MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
  MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
  MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;

  setBooleanContents(ZeroOrNegativeOneBooleanContent);
  setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);

  // Jump is Expensive. Don't create extra control flow for 'and', 'or'
  // condition branches.
  setJumpIsExpensive(true);

  // By default, use the Source scheduling
  if (sched4reg)
    setSchedulingPreference(Sched::RegPressure);
  else
    setSchedulingPreference(Sched::Source);

  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
  addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
  addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);

  // Operations not directly supported by NVPTX.
  setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
  setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
  setOperationAction(ISD::SELECT_CC, MVT::i1, Expand);
  setOperationAction(ISD::SELECT_CC, MVT::i8, Expand);
  setOperationAction(ISD::SELECT_CC, MVT::i16, Expand);
  setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
  setOperationAction(ISD::SELECT_CC, MVT::i64, Expand);
  setOperationAction(ISD::BR_CC, MVT::f32, Expand);
  setOperationAction(ISD::BR_CC, MVT::f64, Expand);
  setOperationAction(ISD::BR_CC, MVT::i1, Expand);
  setOperationAction(ISD::BR_CC, MVT::i8, Expand);
  setOperationAction(ISD::BR_CC, MVT::i16, Expand);
  setOperationAction(ISD::BR_CC, MVT::i32, Expand);
  setOperationAction(ISD::BR_CC, MVT::i64, Expand);
  // Some SIGN_EXTEND_INREG can be done using cvt instruction.
  // For others we will expand to a SHL/SRA pair.
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);

  setOperationAction(ISD::SHL_PARTS, MVT::i32, Custom);
  setOperationAction(ISD::SRA_PARTS, MVT::i32, Custom);
  setOperationAction(ISD::SRL_PARTS, MVT::i32, Custom);
  setOperationAction(ISD::SHL_PARTS, MVT::i64, Custom);
  setOperationAction(ISD::SRA_PARTS, MVT::i64, Custom);
  setOperationAction(ISD::SRL_PARTS, MVT::i64, Custom);

  if (nvptxSubtarget.hasROT64()) {
    setOperationAction(ISD::ROTL, MVT::i64, Legal);
    setOperationAction(ISD::ROTR, MVT::i64, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i64, Expand);
    setOperationAction(ISD::ROTR, MVT::i64, Expand);
  }
  if (nvptxSubtarget.hasROT32()) {
    setOperationAction(ISD::ROTL, MVT::i32, Legal);
    setOperationAction(ISD::ROTR, MVT::i32, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i32, Expand);
    setOperationAction(ISD::ROTR, MVT::i32, Expand);
  }

  setOperationAction(ISD::ROTL, MVT::i16, Expand);
  setOperationAction(ISD::ROTR, MVT::i16, Expand);
  setOperationAction(ISD::ROTL, MVT::i8, Expand);
  setOperationAction(ISD::ROTR, MVT::i8, Expand);
  setOperationAction(ISD::BSWAP, MVT::i16, Expand);
  setOperationAction(ISD::BSWAP, MVT::i32, Expand);
  setOperationAction(ISD::BSWAP, MVT::i64, Expand);

  // Indirect branch is not supported.
  // This also disables Jump Table creation.
  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
  setOperationAction(ISD::BRIND, MVT::Other, Expand);

  setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
  setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);

  // We want to legalize constant-related memmove and memcpy intrinsics.
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);

  // Turn FP extload into load/fextend
  setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
  // Turn FP truncstore into trunc + store.
  setTruncStoreAction(MVT::f64, MVT::f32, Expand);

  // PTX does not support load / store predicate registers
  setOperationAction(ISD::LOAD, MVT::i1, Custom);
  setOperationAction(ISD::STORE, MVT::i1, Custom);

  setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
  setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
  setTruncStoreAction(MVT::i64, MVT::i1, Expand);
  setTruncStoreAction(MVT::i32, MVT::i1, Expand);
  setTruncStoreAction(MVT::i16, MVT::i1, Expand);
  setTruncStoreAction(MVT::i8, MVT::i1, Expand);

  // This is legal in NVPTX
  setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
  setOperationAction(ISD::ConstantFP, MVT::f32, Legal);

  // TRAP can be lowered to PTX trap
  setOperationAction(ISD::TRAP, MVT::Other, Legal);

  setOperationAction(ISD::ADDC, MVT::i64, Expand);
  setOperationAction(ISD::ADDE, MVT::i64, Expand);

  // Register custom handling for vector loads/stores
  for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
       ++i) {
    MVT VT = (MVT::SimpleValueType) i;
    if (IsPTXVectorType(VT)) {
      setOperationAction(ISD::LOAD, VT, Custom);
      setOperationAction(ISD::STORE, VT, Custom);
      setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
    }
  }

  // Custom handling for i8 intrinsics
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);

  setOperationAction(ISD::CTLZ, MVT::i16, Legal);
  setOperationAction(ISD::CTLZ, MVT::i32, Legal);
  setOperationAction(ISD::CTLZ, MVT::i64, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
  setOperationAction(ISD::CTTZ, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ, MVT::i64, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
  setOperationAction(ISD::CTPOP, MVT::i16, Legal);
  setOperationAction(ISD::CTPOP, MVT::i32, Legal);
  setOperationAction(ISD::CTPOP, MVT::i64, Legal);

  // We have some custom DAG combine patterns for these nodes
  setTargetDAGCombine(ISD::ADD);
  setTargetDAGCombine(ISD::AND);
  setTargetDAGCombine(ISD::FADD);
  setTargetDAGCombine(ISD::MUL);
  setTargetDAGCombine(ISD::SHL);

  // Now deduce the register properties from the actions specified above.
  computeRegisterProperties();
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
  switch (Opcode) {
  default:
    return nullptr;
  case NVPTXISD::CALL:
    return "NVPTXISD::CALL";
  case NVPTXISD::RET_FLAG:
    return "NVPTXISD::RET_FLAG";
  case NVPTXISD::Wrapper:
    return "NVPTXISD::Wrapper";
  case NVPTXISD::DeclareParam:
    return "NVPTXISD::DeclareParam";
  case NVPTXISD::DeclareScalarParam:
    return "NVPTXISD::DeclareScalarParam";
  case NVPTXISD::DeclareRet:
    return "NVPTXISD::DeclareRet";
  case NVPTXISD::DeclareRetParam:
    return "NVPTXISD::DeclareRetParam";
  case NVPTXISD::PrintCall:
    return "NVPTXISD::PrintCall";
  case NVPTXISD::LoadParam:
    return "NVPTXISD::LoadParam";
  case NVPTXISD::LoadParamV2:
    return "NVPTXISD::LoadParamV2";
  case NVPTXISD::LoadParamV4:
    return "NVPTXISD::LoadParamV4";
  case NVPTXISD::StoreParam:
    return "NVPTXISD::StoreParam";
  case NVPTXISD::StoreParamV2:
    return "NVPTXISD::StoreParamV2";
  case NVPTXISD::StoreParamV4:
    return "NVPTXISD::StoreParamV4";
  case NVPTXISD::StoreParamS32:
    return "NVPTXISD::StoreParamS32";
  case NVPTXISD::StoreParamU32:
    return "NVPTXISD::StoreParamU32";
  case NVPTXISD::CallArgBegin:
    return "NVPTXISD::CallArgBegin";
  case NVPTXISD::CallArg:
    return "NVPTXISD::CallArg";
  case NVPTXISD::LastCallArg:
    return "NVPTXISD::LastCallArg";
  case NVPTXISD::CallArgEnd:
    return "NVPTXISD::CallArgEnd";
  case NVPTXISD::CallVoid:
    return "NVPTXISD::CallVoid";
  case NVPTXISD::CallVal:
    return "NVPTXISD::CallVal";
  case NVPTXISD::CallSymbol:
    return "NVPTXISD::CallSymbol";
  case NVPTXISD::Prototype:
    return "NVPTXISD::Prototype";
  case NVPTXISD::MoveParam:
    return "NVPTXISD::MoveParam";
  case NVPTXISD::StoreRetval:
    return "NVPTXISD::StoreRetval";
  case NVPTXISD::StoreRetvalV2:
    return "NVPTXISD::StoreRetvalV2";
  case NVPTXISD::StoreRetvalV4:
    return "NVPTXISD::StoreRetvalV4";
  case NVPTXISD::PseudoUseParam:
    return "NVPTXISD::PseudoUseParam";
  case NVPTXISD::RETURN:
    return "NVPTXISD::RETURN";
  case NVPTXISD::CallSeqBegin:
    return "NVPTXISD::CallSeqBegin";
  case NVPTXISD::CallSeqEnd:
    return "NVPTXISD::CallSeqEnd";
  case NVPTXISD::CallPrototype:
    return "NVPTXISD::CallPrototype";
  case NVPTXISD::LoadV2:
    return "NVPTXISD::LoadV2";
  case NVPTXISD::LoadV4:
    return "NVPTXISD::LoadV4";
  case NVPTXISD::LDGV2:
    return "NVPTXISD::LDGV2";
  case NVPTXISD::LDGV4:
    return "NVPTXISD::LDGV4";
  case NVPTXISD::LDUV2:
    return "NVPTXISD::LDUV2";
  case NVPTXISD::LDUV4:
    return "NVPTXISD::LDUV4";
  case NVPTXISD::StoreV2:
    return "NVPTXISD::StoreV2";
  case NVPTXISD::StoreV4:
    return "NVPTXISD::StoreV4";
  case NVPTXISD::FUN_SHFL_CLAMP:
    return "NVPTXISD::FUN_SHFL_CLAMP";
  case NVPTXISD::FUN_SHFR_CLAMP:
    return "NVPTXISD::FUN_SHFR_CLAMP";
  case NVPTXISD::IMAD:
    return "NVPTXISD::IMAD";
  case NVPTXISD::MUL_WIDE_SIGNED:
    return "NVPTXISD::MUL_WIDE_SIGNED";
  case NVPTXISD::MUL_WIDE_UNSIGNED:
    return "NVPTXISD::MUL_WIDE_UNSIGNED";
  case NVPTXISD::Tex1DFloatI32:        return "NVPTXISD::Tex1DFloatI32";
  case NVPTXISD::Tex1DFloatFloat:      return "NVPTXISD::Tex1DFloatFloat";
  case NVPTXISD::Tex1DFloatFloatLevel:
    return "NVPTXISD::Tex1DFloatFloatLevel";
  case NVPTXISD::Tex1DFloatFloatGrad:
    return "NVPTXISD::Tex1DFloatFloatGrad";
  case NVPTXISD::Tex1DI32I32:          return "NVPTXISD::Tex1DI32I32";
  case NVPTXISD::Tex1DI32Float:        return "NVPTXISD::Tex1DI32Float";
  case NVPTXISD::Tex1DI32FloatLevel:
    return "NVPTXISD::Tex1DI32FloatLevel";
  case NVPTXISD::Tex1DI32FloatGrad:
    return "NVPTXISD::Tex1DI32FloatGrad";
  case NVPTXISD::Tex1DArrayFloatI32:   return "NVPTXISD::Tex1DArrayFloatI32";
  case NVPTXISD::Tex1DArrayFloatFloat: return "NVPTXISD::Tex1DArrayFloatFloat";
  case NVPTXISD::Tex1DArrayFloatFloatLevel:
    return "NVPTXISD::Tex1DArrayFloatFloatLevel";
  case NVPTXISD::Tex1DArrayFloatFloatGrad:
    return "NVPTXISD::Tex1DArrayFloatFloatGrad";
  case NVPTXISD::Tex1DArrayI32I32:     return "NVPTXISD::Tex1DArrayI32I32";
  case NVPTXISD::Tex1DArrayI32Float:   return "NVPTXISD::Tex1DArrayI32Float";
  case NVPTXISD::Tex1DArrayI32FloatLevel:
    return "NVPTXISD::Tex1DArrayI32FloatLevel";
  case NVPTXISD::Tex1DArrayI32FloatGrad:
    return "NVPTXISD::Tex1DArrayI32FloatGrad";
  case NVPTXISD::Tex2DFloatI32:        return "NVPTXISD::Tex2DFloatI32";
  case NVPTXISD::Tex2DFloatFloat:      return "NVPTXISD::Tex2DFloatFloat";
  case NVPTXISD::Tex2DFloatFloatLevel:
    return "NVPTXISD::Tex2DFloatFloatLevel";
  case NVPTXISD::Tex2DFloatFloatGrad:
    return "NVPTXISD::Tex2DFloatFloatGrad";
  case NVPTXISD::Tex2DI32I32:          return "NVPTXISD::Tex2DI32I32";
  case NVPTXISD::Tex2DI32Float:        return "NVPTXISD::Tex2DI32Float";
  case NVPTXISD::Tex2DI32FloatLevel:
    return "NVPTXISD::Tex2DI32FloatLevel";
  case NVPTXISD::Tex2DI32FloatGrad:
    return "NVPTXISD::Tex2DI32FloatGrad";
  case NVPTXISD::Tex2DArrayFloatI32:   return "NVPTXISD::Tex2DArrayFloatI32";
  case NVPTXISD::Tex2DArrayFloatFloat: return "NVPTXISD::Tex2DArrayFloatFloat";
  case NVPTXISD::Tex2DArrayFloatFloatLevel:
    return "NVPTXISD::Tex2DArrayFloatFloatLevel";
  case NVPTXISD::Tex2DArrayFloatFloatGrad:
    return "NVPTXISD::Tex2DArrayFloatFloatGrad";
  case NVPTXISD::Tex2DArrayI32I32:     return "NVPTXISD::Tex2DArrayI32I32";
  case NVPTXISD::Tex2DArrayI32Float:   return "NVPTXISD::Tex2DArrayI32Float";
  case NVPTXISD::Tex2DArrayI32FloatLevel:
    return "NVPTXISD::Tex2DArrayI32FloatLevel";
  case NVPTXISD::Tex2DArrayI32FloatGrad:
    return "NVPTXISD::Tex2DArrayI32FloatGrad";
  case NVPTXISD::Tex3DFloatI32:        return "NVPTXISD::Tex3DFloatI32";
  case NVPTXISD::Tex3DFloatFloat:      return "NVPTXISD::Tex3DFloatFloat";
  case NVPTXISD::Tex3DFloatFloatLevel:
    return "NVPTXISD::Tex3DFloatFloatLevel";
  case NVPTXISD::Tex3DFloatFloatGrad:
    return "NVPTXISD::Tex3DFloatFloatGrad";
  case NVPTXISD::Tex3DI32I32:          return "NVPTXISD::Tex3DI32I32";
  case NVPTXISD::Tex3DI32Float:        return "NVPTXISD::Tex3DI32Float";
  case NVPTXISD::Tex3DI32FloatLevel:
    return "NVPTXISD::Tex3DI32FloatLevel";
  case NVPTXISD::Tex3DI32FloatGrad:
    return "NVPTXISD::Tex3DI32FloatGrad";

  case NVPTXISD::Suld1DI8Trap:          return "NVPTXISD::Suld1DI8Trap";
  case NVPTXISD::Suld1DI16Trap:         return "NVPTXISD::Suld1DI16Trap";
  case NVPTXISD::Suld1DI32Trap:         return "NVPTXISD::Suld1DI32Trap";
  case NVPTXISD::Suld1DV2I8Trap:        return "NVPTXISD::Suld1DV2I8Trap";
  case NVPTXISD::Suld1DV2I16Trap:       return "NVPTXISD::Suld1DV2I16Trap";
  case NVPTXISD::Suld1DV2I32Trap:       return "NVPTXISD::Suld1DV2I32Trap";
  case NVPTXISD::Suld1DV4I8Trap:        return "NVPTXISD::Suld1DV4I8Trap";
  case NVPTXISD::Suld1DV4I16Trap:       return "NVPTXISD::Suld1DV4I16Trap";
  case NVPTXISD::Suld1DV4I32Trap:       return "NVPTXISD::Suld1DV4I32Trap";

  case NVPTXISD::Suld1DArrayI8Trap:     return "NVPTXISD::Suld1DArrayI8Trap";
  case NVPTXISD::Suld1DArrayI16Trap:    return "NVPTXISD::Suld1DArrayI16Trap";
  case NVPTXISD::Suld1DArrayI32Trap:    return "NVPTXISD::Suld1DArrayI32Trap";
  case NVPTXISD::Suld1DArrayV2I8Trap:   return "NVPTXISD::Suld1DArrayV2I8Trap";
  case NVPTXISD::Suld1DArrayV2I16Trap:  return "NVPTXISD::Suld1DArrayV2I16Trap";
  case NVPTXISD::Suld1DArrayV2I32Trap:  return "NVPTXISD::Suld1DArrayV2I32Trap";
  case NVPTXISD::Suld1DArrayV4I8Trap:   return "NVPTXISD::Suld1DArrayV4I8Trap";
  case NVPTXISD::Suld1DArrayV4I16Trap:  return "NVPTXISD::Suld1DArrayV4I16Trap";
  case NVPTXISD::Suld1DArrayV4I32Trap:  return "NVPTXISD::Suld1DArrayV4I32Trap";

  case NVPTXISD::Suld2DI8Trap:          return "NVPTXISD::Suld2DI8Trap";
  case NVPTXISD::Suld2DI16Trap:         return "NVPTXISD::Suld2DI16Trap";
  case NVPTXISD::Suld2DI32Trap:         return "NVPTXISD::Suld2DI32Trap";
  case NVPTXISD::Suld2DV2I8Trap:        return "NVPTXISD::Suld2DV2I8Trap";
  case NVPTXISD::Suld2DV2I16Trap:       return "NVPTXISD::Suld2DV2I16Trap";
  case NVPTXISD::Suld2DV2I32Trap:       return "NVPTXISD::Suld2DV2I32Trap";
  case NVPTXISD::Suld2DV4I8Trap:        return "NVPTXISD::Suld2DV4I8Trap";
  case NVPTXISD::Suld2DV4I16Trap:       return "NVPTXISD::Suld2DV4I16Trap";
  case NVPTXISD::Suld2DV4I32Trap:       return "NVPTXISD::Suld2DV4I32Trap";

  case NVPTXISD::Suld2DArrayI8Trap:     return "NVPTXISD::Suld2DArrayI8Trap";
  case NVPTXISD::Suld2DArrayI16Trap:    return "NVPTXISD::Suld2DArrayI16Trap";
  case NVPTXISD::Suld2DArrayI32Trap:    return "NVPTXISD::Suld2DArrayI32Trap";
  case NVPTXISD::Suld2DArrayV2I8Trap:   return "NVPTXISD::Suld2DArrayV2I8Trap";
  case NVPTXISD::Suld2DArrayV2I16Trap:  return "NVPTXISD::Suld2DArrayV2I16Trap";
  case NVPTXISD::Suld2DArrayV2I32Trap:  return "NVPTXISD::Suld2DArrayV2I32Trap";
  case NVPTXISD::Suld2DArrayV4I8Trap:   return "NVPTXISD::Suld2DArrayV4I8Trap";
  case NVPTXISD::Suld2DArrayV4I16Trap:  return "NVPTXISD::Suld2DArrayV4I16Trap";
  case NVPTXISD::Suld2DArrayV4I32Trap:  return "NVPTXISD::Suld2DArrayV4I32Trap";

  case NVPTXISD::Suld3DI8Trap:          return "NVPTXISD::Suld3DI8Trap";
  case NVPTXISD::Suld3DI16Trap:         return "NVPTXISD::Suld3DI16Trap";
  case NVPTXISD::Suld3DI32Trap:         return "NVPTXISD::Suld3DI32Trap";
  case NVPTXISD::Suld3DV2I8Trap:        return "NVPTXISD::Suld3DV2I8Trap";
  case NVPTXISD::Suld3DV2I16Trap:       return "NVPTXISD::Suld3DV2I16Trap";
  case NVPTXISD::Suld3DV2I32Trap:       return "NVPTXISD::Suld3DV2I32Trap";
  case NVPTXISD::Suld3DV4I8Trap:        return "NVPTXISD::Suld3DV4I8Trap";
  case NVPTXISD::Suld3DV4I16Trap:       return "NVPTXISD::Suld3DV4I16Trap";
  case NVPTXISD::Suld3DV4I32Trap:       return "NVPTXISD::Suld3DV4I32Trap";
  }
}

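// Prefer splitting <N x i1> vectors into individual i1 values rather than
// widening or promoting them; PTX predicate registers are scalar, and i1
// loads/stores already get custom handling above.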
TargetLoweringBase::LegalizeTypeAction
NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const {
  if (VT.getVectorNumElements() != 1 && VT.getScalarType() == MVT::i1)
    return TypeSplitVector;

  return TargetLoweringBase::getPreferredVectorAction(VT);
}

SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
  SDLoc dl(Op);
  const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
  Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
  return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
}

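// For example, for a callee declared as "i32 @foo(float, i8*)", this returns
// (assuming 64-bit pointers) a string along the lines of:
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b64 _);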
std::string
NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  unsigned retAlignment,
                                  const ImmutableCallSite *CS) const {

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return "";

  std::stringstream O;
  O << "prototype_" << uniqueCallSite << " : .callprototype ";

  if (retTy->getTypeID() == Type::VoidTyID) {
    O << "()";
  } else {
    O << "(";
    if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
      unsigned size = 0;
      if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
        size = ITy->getBitWidth();
        if (size < 32)
          size = 32;
      } else {
        assert(retTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = retTy->getPrimitiveSizeInBits();
      }

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(retTy)) {
      O << ".param .b" << getPointerTy().getSizeInBits() << " _";
    } else {
      if ((retTy->getTypeID() == Type::StructTyID) ||
          isa<VectorType>(retTy)) {
        O << ".param .align "
          << retAlignment
          << " .b8 _["
          << getDataLayout()->getTypeAllocSize(retTy) << "]";
      } else {
        assert(false && "Unknown return type");
      }
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;
  MVT thePointerTy = getPointerTy();

  unsigned OIdx = 0;
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    Type *Ty = Args[i].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (Ty->isAggregateType() || Ty->isVectorTy()) {
        unsigned align = 0;
        const CallInst *CallI = cast<CallInst>(CS->getInstruction());
        const DataLayout *TD = getDataLayout();
        // +1 because index 0 is reserved for return type alignment
        if (!llvm::getAlign(*CallI, i + 1, align))
          align = TD->getABITypeAlignment(Ty);
        unsigned sz = TD->getTypeAllocSize(Ty);
        O << ".param .align " << align << " .b8 ";
        O << "_";
        O << "[" << sz << "]";
        // update the index for Outs
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, Ty, vtparts);
        if (unsigned len = vtparts.size())
          OIdx += len - 1;
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(Ty) == Outs[OIdx].VT ||
              (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        if (sz < 32)
          sz = 32;
      } else if (isa<PointerType>(Ty))
        sz = thePointerTy.getSizeInBits();
      else
        sz = Ty->getPrimitiveSizeInBits();
      O << ".param .b" << sz << " ";
      O << "_";
      continue;
    }
    const PointerType *PTy = dyn_cast<PointerType>(Ty);
    assert(PTy && "Param with byval attribute should be a pointer type");
    Type *ETy = PTy->getElementType();

    unsigned align = Outs[OIdx].Flags.getByValAlign();
    unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
    O << ".param .align " << align << " .b8 ";
    O << "_";
    O << "[" << sz << "]";
  }
  O << ");";
  return O.str();
}

unsigned
NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
                                          const ImmutableCallSite *CS,
                                          Type *Ty,
                                          unsigned Idx) const {
  const DataLayout *TD = getDataLayout();
  unsigned Align = 0;
  const Value *DirectCallee = CS->getCalledFunction();

  if (!DirectCallee) {
    // We don't have a direct function symbol, but that may be because of
    // constant cast instructions in the call.
    const Instruction *CalleeI = CS->getInstruction();
    assert(CalleeI && "Call target is not a function or derived value?");

    // With bitcast'd call targets, the instruction will be the call
    if (isa<CallInst>(CalleeI)) {
      // Check if we have call alignment metadata
      if (llvm::getAlign(*cast<CallInst>(CalleeI), Idx, Align))
        return Align;

      const Value *CalleeV = cast<CallInst>(CalleeI)->getCalledValue();
      // Ignore any bitcast instructions
      while (isa<ConstantExpr>(CalleeV)) {
        const ConstantExpr *CE = cast<ConstantExpr>(CalleeV);
        if (!CE->isCast())
          break;
        // Look through the bitcast
        CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0);
      }

      // We have now looked past all of the bitcasts.  Do we finally have a
      // Function?
      if (isa<Function>(CalleeV))
        DirectCallee = CalleeV;
    }
  }

  // Check for function alignment information if we found that the
  // ultimate target is a Function
  if (DirectCallee)
    if (llvm::getAlign(*cast<Function>(DirectCallee), Idx, Align))
      return Align;

  // Call is indirect or alignment information is not available, fall back to
  // the ABI type alignment
  return TD->getABITypeAlignment(Ty);
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                       SmallVectorImpl<SDValue> &InVals) const {
  SelectionDAG &DAG = CLI.DAG;
  SDLoc dl = CLI.DL;
  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
  SDValue Chain = CLI.Chain;
  SDValue Callee = CLI.Callee;
  bool &isTailCall = CLI.IsTailCall;
  ArgListTy &Args = CLI.getArgs();
  Type *retTy = CLI.RetTy;
  ImmutableCallSite *CS = CLI.CS;

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;
  const DataLayout *TD = getDataLayout();
  MachineFunction &MF = DAG.getMachineFunction();
  const Function *F = MF.getFunction();

  SDValue tempChain = Chain;
  Chain =
      DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
                           dl);
  SDValue InFlag = Chain.getValue(1);

  unsigned paramCount = 0;
  // Args.size() and Outs.size() need not match.
  // Outs.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Outs)
  //   * if there is a vector argument with more than typical vector-length
  //     elements (generally if more than 4) where each vector element is
  //     individually present in Outs.
  // So a different index should be used for indexing into Outs/OutVals.
  // See similar issue in LowerFormalArguments.
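  // For example, an argument of type {i32, float} is a single Args entry but
  // contributes two entries (i32, f32) to Outs/OutVals.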
  unsigned OIdx = 0;
  // Declare the .params or .regs needed to pass values to the function.
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    EVT VT = Outs[OIdx].VT;
    Type *Ty = Args[i].Ty;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (Ty->isAggregateType()) {
        // aggregate
        SmallVector<EVT, 16> vtparts;
        SmallVector<uint64_t, 16> Offsets;
        ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);

        unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
        // declare .param .align <align> .b8 .param<n>[<size>];
        unsigned sz = TD->getTypeAllocSize(Ty);
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                            DeclareParamOps);
        InFlag = Chain.getValue(1);
        for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
          EVT elemtype = vtparts[j];
          unsigned ArgAlign = GreatestCommonDivisor64(align, Offsets[j]);
          if (elemtype.isInteger() && (sz < 8))
            sz = 8;
          SDValue StVal = OutVals[OIdx];
          if (elemtype.getSizeInBits() < 16) {
            StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
          }
          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(Offsets[j], MVT::i32),
                                     StVal, InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
                                          CopyParamVTs, CopyParamOps,
                                          elemtype, MachinePointerInfo(),
                                          ArgAlign);
          InFlag = Chain.getValue(1);
          ++OIdx;
        }
        if (vtparts.size() > 0)
          --OIdx;
        ++paramCount;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
        // declare .param .align <align> .b8 .param<n>[<size>];
        unsigned sz = TD->getTypeAllocSize(Ty);
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                            DeclareParamOps);
        InFlag = Chain.getValue(1);
        unsigned NumElts = ObjectVT.getVectorNumElements();
        EVT EltVT = ObjectVT.getVectorElementType();
        EVT MemVT = EltVT;
        bool NeedExtend = false;
        if (EltVT.getSizeInBits() < 16) {
          NeedExtend = true;
          EltVT = MVT::i16;
        }

        // V1 store
        if (NumElts == 1) {
          SDValue Elt = OutVals[OIdx++];
          if (NeedExtend)
            Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);

          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(0, MVT::i32), Elt,
                                     InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
                                          CopyParamVTs, CopyParamOps,
                                          MemVT, MachinePointerInfo());
          InFlag = Chain.getValue(1);
        } else if (NumElts == 2) {
          SDValue Elt0 = OutVals[OIdx++];
          SDValue Elt1 = OutVals[OIdx++];
          if (NeedExtend) {
            Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
            Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
          }

          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(0, MVT::i32), Elt0, Elt1,
                                     InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
                                          CopyParamVTs, CopyParamOps,
                                          MemVT, MachinePointerInfo());
          InFlag = Chain.getValue(1);
        } else {
          unsigned curOffset = 0;
          // V4 stores
          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
          // the vector will be expanded to a power of 2 elements, so we know we
          // can always round up to the next multiple of 4 when creating the
          // vector stores.
          // e.g.  4 elem => 1 st.v4
          //       6 elem => 2 st.v4
          //       8 elem => 2 st.v4
          //      11 elem => 3 st.v4
          unsigned VecSize = 4;
          if (EltVT.getSizeInBits() == 64)
            VecSize = 2;

          // This is potentially only part of a vector, so assume all elements
          // are packed together.
          unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;

          for (unsigned i = 0; i < NumElts; i += VecSize) {
            // Get values
            SDValue StoreVal;
            SmallVector<SDValue, 8> Ops;
            Ops.push_back(Chain);
            Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
            Ops.push_back(DAG.getConstant(curOffset, MVT::i32));

            unsigned Opc = NVPTXISD::StoreParamV2;

            StoreVal = OutVals[OIdx++];
            if (NeedExtend)
              StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
            Ops.push_back(StoreVal);

            if (i + 1 < NumElts) {
              StoreVal = OutVals[OIdx++];
              if (NeedExtend)
                StoreVal =
                    DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
            } else {
              StoreVal = DAG.getUNDEF(EltVT);
            }
            Ops.push_back(StoreVal);

            if (VecSize == 4) {
              Opc = NVPTXISD::StoreParamV4;
              if (i + 2 < NumElts) {
                StoreVal = OutVals[OIdx++];
                if (NeedExtend)
                  StoreVal =
                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
              } else {
                StoreVal = DAG.getUNDEF(EltVT);
              }
              Ops.push_back(StoreVal);

              if (i + 3 < NumElts) {
                StoreVal = OutVals[OIdx++];
                if (NeedExtend)
                  StoreVal =
                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
              } else {
                StoreVal = DAG.getUNDEF(EltVT);
              }
              Ops.push_back(StoreVal);
            }

            Ops.push_back(InFlag);

            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
            Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, Ops,
                                            MemVT, MachinePointerInfo());
            InFlag = Chain.getValue(1);
            curOffset += PerStoreOffset;
          }
        }
        ++paramCount;
        --OIdx;
        continue;
      }
      // Plain scalar
      // for ABI,    declare .param .b<size> .param<n>;
      unsigned sz = VT.getSizeInBits();
      bool needExtend = false;
      if (VT.isInteger()) {
        if (sz < 16)
          needExtend = true;
        if (sz < 32)
          sz = 32;
      }
      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareParamOps[] = { Chain,
                                    DAG.getConstant(paramCount, MVT::i32),
                                    DAG.getConstant(sz, MVT::i32),
                                    DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
                          DeclareParamOps);
      InFlag = Chain.getValue(1);
      SDValue OutV = OutVals[OIdx];
      if (needExtend) {
        // zext/sext i1 to i16
        unsigned opc = ISD::ZERO_EXTEND;
        if (Outs[OIdx].Flags.isSExt())
          opc = ISD::SIGN_EXTEND;
        OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
      }
      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                 DAG.getConstant(0, MVT::i32), OutV, InFlag };

      unsigned opcode = NVPTXISD::StoreParam;
      if (Outs[OIdx].Flags.isZExt())
        opcode = NVPTXISD::StoreParamU32;
      else if (Outs[OIdx].Flags.isSExt())
        opcode = NVPTXISD::StoreParamS32;
      Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps,
                                      VT, MachinePointerInfo());

      InFlag = Chain.getValue(1);
      ++paramCount;
      continue;
    }
    // struct or vector
    SmallVector<EVT, 16> vtparts;
    SmallVector<uint64_t, 16> Offsets;
    const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
    assert(PTy && "Type of a byval parameter should be pointer");
    ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0);

    // declare .param .align <align> .b8 .param<n>[<size>];
    unsigned sz = Outs[OIdx].Flags.getByValSize();
    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
    // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
    // so we don't need to worry about natural alignment or not.
    // See TargetLowering::LowerCallTo().
    SDValue DeclareParamOps[] = {
      Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
      DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
      InFlag
    };
    Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                        DeclareParamOps);
    InFlag = Chain.getValue(1);
    for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
      EVT elemtype = vtparts[j];
      int curOffset = Offsets[j];
      unsigned PartAlign = GreatestCommonDivisor64(ArgAlign, curOffset);
      SDValue srcAddr =
          DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
                      DAG.getConstant(curOffset, getPointerTy()));
      SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                   MachinePointerInfo(), false, false, false,
                                   PartAlign);
      if (elemtype.getSizeInBits() < 16) {
        theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
      }
      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                 DAG.getConstant(curOffset, MVT::i32), theVal,
                                 InFlag };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
                                      CopyParamOps, elemtype,
                                      MachinePointerInfo());

      InFlag = Chain.getValue(1);
    }
    ++paramCount;
  }

  GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
  unsigned retAlignment = 0;

  // Handle Result
  if (Ins.size() > 0) {
    SmallVector<EVT, 16> resvtparts;
    ComputeValueVTs(*this, retTy, resvtparts);

    // Declare
    //  .param .align 16 .b8 retval0[<size-in-bytes>], or
    //  .param .b<size-in-bits> retval0
    unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
    if (retTy->isSingleValueType()) {
      // Scalar needs to be at least 32bit wide
      if (resultsz < 32)
        resultsz = 32;
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                                  DAG.getConstant(resultsz, MVT::i32),
                                  DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
                          DeclareRetOps);
      InFlag = Chain.getValue(1);
    } else {
      retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = { Chain,
                                  DAG.getConstant(retAlignment, MVT::i32),
                                  DAG.getConstant(resultsz / 8, MVT::i32),
                                  DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
                          DeclareRetOps);
      InFlag = Chain.getValue(1);
    }
  }

  if (!Func) {
    // This is the indirect function call case: PTX requires a prototype of the
    // form
    // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
    // to be emitted, and the label has to be used as the last arg of the call
    // instruction. The prototype is embedded in a string and used as the
    // operand for a CallPrototype SDNode, which will print out to the value of
    // the string.
    SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    std::string Proto = getPrototype(retTy, Args, Outs, retAlignment, CS);
    const char *ProtoStr =
      nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str();
    SDValue ProtoOps[] = {
      Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InFlag,
    };
    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
    InFlag = Chain.getValue(1);
  }
  // Op to just print "call"
  SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue PrintCallOps[] = {
    Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
  };
  Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
                      dl, PrintCallVTs, PrintCallOps);
  InFlag = Chain.getValue(1);

  // Ops to print out the function name
  SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallVoidOps[] = { Chain, Callee, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
  InFlag = Chain.getValue(1);

  // Ops to print out the param list
  SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgBeginOps[] = { Chain, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
                      CallArgBeginOps);
  InFlag = Chain.getValue(1);

  for (unsigned i = 0, e = paramCount; i != e; ++i) {
    unsigned opcode;
    if (i == (e - 1))
      opcode = NVPTXISD::LastCallArg;
    else
      opcode = NVPTXISD::CallArg;
    SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                             DAG.getConstant(i, MVT::i32), InFlag };
    Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
    InFlag = Chain.getValue(1);
  }
  SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
                              InFlag };
  Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
  InFlag = Chain.getValue(1);

  if (!Func) {
    SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
                               InFlag };
    Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
    InFlag = Chain.getValue(1);
  }

  // Generate loads from param memory/moves from registers for result
  if (Ins.size() > 0) {
    if (retTy && retTy->isVectorTy()) {
      EVT ObjectVT = getValueType(retTy);
      unsigned NumElts = ObjectVT.getVectorNumElements();
      EVT EltVT = ObjectVT.getVectorElementType();
      assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
                                                        ObjectVT) == NumElts &&
             "Vector was not scalarized");
      unsigned sz = EltVT.getSizeInBits();
      bool needTruncate = sz < 8;

      if (NumElts == 1) {
        // Just a simple load
        SmallVector<EVT, 4> LoadRetVTs;
        if (EltVT == MVT::i1 || EltVT == MVT::i8) {
          // If loading i1/i8 result, generate
          //   load.b8 i16
          //   if i1
          //   trunc i16 to i1
          LoadRetVTs.push_back(MVT::i16);
        } else
          LoadRetVTs.push_back(EltVT);
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);
        SmallVector<SDValue, 4> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParam, dl,
            DAG.getVTList(LoadRetVTs), LoadRetOps, EltVT, MachinePointerInfo());
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        SDValue Ret0 = retval;
        if (needTruncate)
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
        InVals.push_back(Ret0);
      } else if (NumElts == 2) {
        // LoadV2
        SmallVector<EVT, 4> LoadRetVTs;
        if (EltVT == MVT::i1 || EltVT == MVT::i8) {
          // If loading i1/i8 result, generate
          //   load.b8 i16
          //   if i1
          //   trunc i16 to i1
          LoadRetVTs.push_back(MVT::i16);
          LoadRetVTs.push_back(MVT::i16);
        } else {
          LoadRetVTs.push_back(EltVT);
          LoadRetVTs.push_back(EltVT);
        }
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);
        SmallVector<SDValue, 4> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParamV2, dl,
            DAG.getVTList(LoadRetVTs), LoadRetOps, EltVT, MachinePointerInfo());
        Chain = retval.getValue(2);
        InFlag = retval.getValue(3);
        SDValue Ret0 = retval.getValue(0);
        SDValue Ret1 = retval.getValue(1);
        if (needTruncate) {
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
          InVals.push_back(Ret0);
          Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
          InVals.push_back(Ret1);
        } else {
          InVals.push_back(Ret0);
          InVals.push_back(Ret1);
        }
      } else {
        // Split into N LoadV4
        unsigned Ofst = 0;
        unsigned VecSize = 4;
        unsigned Opc = NVPTXISD::LoadParamV4;
        if (EltVT.getSizeInBits() == 64) {
          VecSize = 2;
          Opc = NVPTXISD::LoadParamV2;
        }
        EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
        for (unsigned i = 0; i < NumElts; i += VecSize) {
          SmallVector<EVT, 8> LoadRetVTs;
          if (EltVT == MVT::i1 || EltVT == MVT::i8) {
            // If loading i1/i8 result, generate
            //   load.b8 i16
            //   if i1
            //   trunc i16 to i1
            for (unsigned j = 0; j < VecSize; ++j)
              LoadRetVTs.push_back(MVT::i16);
          } else {
            for (unsigned j = 0; j < VecSize; ++j)
              LoadRetVTs.push_back(EltVT);
          }
          LoadRetVTs.push_back(MVT::Other);
          LoadRetVTs.push_back(MVT::Glue);
          SmallVector<SDValue, 4> LoadRetOps;
          LoadRetOps.push_back(Chain);
          LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
          LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
          LoadRetOps.push_back(InFlag);
          SDValue retval = DAG.getMemIntrinsicNode(
              Opc, dl, DAG.getVTList(LoadRetVTs),
              LoadRetOps, EltVT, MachinePointerInfo());
          if (VecSize == 2) {
            Chain = retval.getValue(2);
            InFlag = retval.getValue(3);
          } else {
            Chain = retval.getValue(4);
            InFlag = retval.getValue(5);
          }

          for (unsigned j = 0; j < VecSize; ++j) {
            if (i + j >= NumElts)
              break;
            SDValue Elt = retval.getValue(j);
            if (needTruncate)
              Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
            InVals.push_back(Elt);
          }
          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
        }
      }
    } else {
      SmallVector<EVT, 16> VTs;
      SmallVector<uint64_t, 16> Offsets;
      ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
      assert(VTs.size() == Ins.size() && "Bad value decomposition");
      unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0);
      for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
        unsigned sz = VTs[i].getSizeInBits();
        unsigned AlignI = GreatestCommonDivisor64(RetAlign, Offsets[i]);
        bool needTruncate = sz < 8;
        if (VTs[i].isInteger() && (sz < 8))
          sz = 8;

        SmallVector<EVT, 4> LoadRetVTs;
        EVT TheLoadType = VTs[i];
        if (retTy->isIntegerTy() &&
            TD->getTypeAllocSizeInBits(retTy) < 32) {
          // This is for integer types only, and specifically not for
          // aggregates.
          LoadRetVTs.push_back(MVT::i32);
          TheLoadType = MVT::i32;
        } else if (sz < 16) {
          // If loading i1/i8 result, generate
          //   load i8 (-> i16)
          //   trunc i16 to i1/i8
          LoadRetVTs.push_back(MVT::i16);
        } else
          LoadRetVTs.push_back(Ins[i].VT);
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);

        SmallVector<SDValue, 4> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParam, dl,
            DAG.getVTList(LoadRetVTs), LoadRetOps,
            TheLoadType, MachinePointerInfo(), AlignI);
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        SDValue Ret0 = retval.getValue(0);
        if (needTruncate)
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
        InVals.push_back(Ret0);
      }
    }
  }

  Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
                             DAG.getIntPtrConstant(uniqueCallSite + 1, true),
                             InFlag, dl);
  uniqueCallSite++;

  // Set isTailCall to false for now, until we figure out how to express
  // tail call optimization in PTX
  isTailCall = false;
  return Chain;
}

// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
// (see LegalizeDAG.cpp). This is slow and uses local memory.
// We use extract/insert/build vector just as LegalizeOp() did in LLVM 2.5.
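// For example, concat(<2 x float> A, <2 x float> B) becomes
// build_vector(extract(A, 0), extract(A, 1), extract(B, 0), extract(B, 1)).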
SDValue
NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  SDLoc dl(Node);
  SmallVector<SDValue, 8> Ops;
  unsigned NumOperands = Node->getNumOperands();
  for (unsigned i = 0; i < NumOperands; ++i) {
    SDValue SubOp = Node->getOperand(i);
    EVT VVT = SubOp.getNode()->getValueType(0);
    EVT EltVT = VVT.getVectorElementType();
    unsigned NumSubElem = VVT.getVectorNumElements();
    for (unsigned j = 0; j < NumSubElem; ++j) {
      Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
                                DAG.getIntPtrConstant(j)));
    }
  }
  return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), Ops);
}

/// LowerShiftRightParts - Lower SRL_PARTS and SRA_PARTS, which
/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
///    amount, or
/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
///    amount.
SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
                                                  SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(0);
  SDValue ShOpHi = Op.getOperand(1);
  SDValue ShAmt  = Op.getOperand(2);
  unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;

  if (VTBits == 32 && nvptxSubtarget.getSmVersion() >= 35) {

    // For 32-bit shifts on sm_35 and later, we can use the funnel shift
    // 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} >> Amt
    //   dHi = aHi >> Amt
    //   dLo = shf.r.clamp aLo, aHi, Amt
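    //
    // e.g. for Amt = 8, the low word receives the bottom 8 bits of aHi
    // funneled in above aLo >> 8, and the high word is just aHi shifted
    // (logically or arithmetically) by 8.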

    SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
    SDValue Lo = DAG.getNode(NVPTXISD::FUN_SHFR_CLAMP, dl, VT, ShOpLo, ShOpHi,
                             ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  } else {

    // {dHi, dLo} = {aHi, aLo} >> Amt
    // - if (Amt>=size) then
    //      dLo = aHi >> (Amt-size)
    //      dHi = aHi >> Amt (this is either all 0 or all 1)
    //   else
    //      dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
    //      dHi = aHi >> Amt
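    //
    // e.g. for a 64-bit value shifted right by Amt = 40 (size = 32):
    //   dLo = aHi >> 8 and dHi = aHi >> 40 (all zero, or all sign bits for
    //   SRA), which the SETGE compare and SELECT below produce.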

    SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
                                   DAG.getConstant(VTBits, MVT::i32), ShAmt);
    SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
    SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
                                     DAG.getConstant(VTBits, MVT::i32));
    SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
    SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
    SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);

    SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
                               DAG.getConstant(VTBits, MVT::i32), ISD::SETGE);
    SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
    SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}

/// LowerShiftLeftParts - Lower SHL_PARTS, which
/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
///    amount, or
/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
///    amount.
SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
                                                 SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SHL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(0);
  SDValue ShOpHi = Op.getOperand(1);
  SDValue ShAmt  = Op.getOperand(2);

  if (VTBits == 32 && nvptxSubtarget.getSmVersion() >= 35) {

    // For 32-bit shifts on sm_35 and later, we can use the funnel shift
    // 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} << Amt
    //   dHi = shf.l.clamp aLo, aHi, Amt
    //   dLo = aLo << Amt

    SDValue Hi = DAG.getNode(NVPTXISD::FUN_SHFL_CLAMP, dl, VT, ShOpLo, ShOpHi,
                             ShAmt);
    SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  } else {

    // {dHi, dLo} = {aHi, aLo} << Amt
    // - if (Amt>=size) then
    //      dLo = aLo << Amt (all 0)
    //      dHi = aLo << (Amt-size)
    //   else
    //      dLo = aLo << Amt
    //      dHi = (aHi << Amt) | (aLo >> (size-Amt))
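    //
    // e.g. for a 64-bit value shifted left by Amt = 40 (size = 32):
    //   dLo = 0 and dHi = aLo << 8, which the SETGE compare and SELECT
    //   below produce.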

    SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
                                   DAG.getConstant(VTBits, MVT::i32), ShAmt);
    SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
    SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
                                     DAG.getConstant(VTBits, MVT::i32));
    SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
    SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
    SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);

    SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
                               DAG.getConstant(VTBits, MVT::i32), ISD::SETGE);
    SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
    SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}

SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  case ISD::RETURNADDR:
    return SDValue();
  case ISD::FRAMEADDR:
    return SDValue();
  case ISD::GlobalAddress:
    return LowerGlobalAddress(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN:
    return Op;
  case ISD::BUILD_VECTOR:
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  case ISD::LOAD:
    return LowerLOAD(Op, DAG);
  case ISD::SHL_PARTS:
    return LowerShiftLeftParts(Op, DAG);
  case ISD::SRA_PARTS:
  case ISD::SRL_PARTS:
    return LowerShiftRightParts(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}

SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
  if (Op.getValueType() == MVT::i1)
    return LowerLOADi1(Op, DAG);
  else
    return SDValue();
}

// v = ld i1* addr
//   =>
// v1 = ld i8* addr (-> i16)
// v = trunc i16 to i1
SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  LoadSDNode *LD = cast<LoadSDNode>(Node);
  SDLoc dl(Node);
  assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
  assert(Node->getValueType(0) == MVT::i1 &&
         "Custom lowering for i1 load only");
  SDValue newLD =
      DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
                  LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
                  LD->isInvariant(), LD->getAlignment());
  SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
  // The legalizer (the caller) is expecting two values from the legalized
  // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
  // in LegalizeDAG.cpp which also uses MergeValues.
  SDValue Ops[] = { result, LD->getChain() };
  return DAG.getMergeValues(Ops, dl);
}

SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
  EVT ValVT = Op.getOperand(1).getValueType();
  if (ValVT == MVT::i1)
    return LowerSTOREi1(Op, DAG);
  else if (ValVT.isVector())
    return LowerSTOREVector(Op, DAG);
  else
    return SDValue();
}

SDValue
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
  SDNode *N = Op.getNode();
  SDValue Val = N->getOperand(1);
  SDLoc DL(N);
  EVT ValVT = Val.getValueType();

  if (ValVT.isVector()) {
    // We only handle "native" vector sizes for now, e.g. <4 x double> is not
    // legal. We can (and should) split that into two stores of <2 x double>
    // here, but that is left as a TODO for now.
    if (!ValVT.isSimple())
      return SDValue();
    switch (ValVT.getSimpleVT().SimpleTy) {
    default:
      return SDValue();
    case MVT::v2i8:
    case MVT::v2i16:
    case MVT::v2i32:
    case MVT::v2i64:
    case MVT::v2f32:
    case MVT::v2f64:
    case MVT::v4i8:
    case MVT::v4i16:
    case MVT::v4i32:
    case MVT::v4f32:
      // This is a "native" vector type
      break;
    }

    unsigned Opcode = 0;
    EVT EltVT = ValVT.getVectorElementType();
    unsigned NumElts = ValVT.getVectorNumElements();

    // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
    // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
    // stored type to i16 and propagate the "real" type as the memory type.
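    // For example, a <4 x i8> store becomes a StoreV4 whose four value
    // operands are any-extended to i16 below, while the memory VT attached
    // to the node stays v4i8, so only four bytes are actually written.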
    bool NeedExt = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExt = true;

    switch (NumElts) {
    default:
      return SDValue();
    case 2:
      Opcode = NVPTXISD::StoreV2;
      break;
    case 4:
      Opcode = NVPTXISD::StoreV4;
      break;
    }

    SmallVector<SDValue, 8> Ops;

    // First is the chain
    Ops.push_back(N->getOperand(0));

    // Then the split values
    for (unsigned i = 0; i < NumElts; ++i) {
      SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                   DAG.getIntPtrConstant(i));
      if (NeedExt)
        ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
      Ops.push_back(ExtVal);
    }

    // Then any remaining arguments
    for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
      Ops.push_back(N->getOperand(i));
    }

    MemSDNode *MemSD = cast<MemSDNode>(N);

    SDValue NewSt = DAG.getMemIntrinsicNode(
        Opcode, DL, DAG.getVTList(MVT::Other), Ops,
        MemSD->getMemoryVT(), MemSD->getMemOperand());

    return NewSt;
  }

  return SDValue();
}

// st i1 v, addr
//    =>
// v1 = zxt v to i16
// st.u8 i16, addr
SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  SDLoc dl(Node);
  StoreSDNode *ST = cast<StoreSDNode>(Node);
  SDValue Tmp1 = ST->getChain();
  SDValue Tmp2 = ST->getBasePtr();
  SDValue Tmp3 = ST->getValue();
  assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
  unsigned Alignment = ST->getAlignment();
  bool isVolatile = ST->isVolatile();
  bool isNonTemporal = ST->isNonTemporal();
  Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
  SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
                                     ST->getPointerInfo(), MVT::i8, isNonTemporal,
                                     isVolatile, Alignment);
  return Result;
}

SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
                                        int idx, EVT v) const {
  std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
  std::stringstream suffix;
  suffix << idx;
  *name += suffix.str();
  return DAG.getTargetExternalSymbol(name->c_str(), v);
}

SDValue
NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
  std::string ParamSym;
  raw_string_ostream ParamStr(ParamSym);

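  // e.g. the second parameter of function "foo" is named "foo_param_1".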
  ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx;
  ParamStr.flush();

  std::string *SavedStr =
    nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str());
  return DAG.getTargetExternalSymbol(SavedStr->c_str(), v);
}

SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
  return getExtSymb(DAG, ".HLPPARAM", idx);
}

// Check to see if the kernel argument is image*_t or sampler_t

bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
  static const char *const specialTypes[] = { "struct._image2d_t",
                                              "struct._image3d_t",
                                              "struct._sampler_t" };

  const Type *Ty = arg->getType();
  const PointerType *PTy = dyn_cast<PointerType>(Ty);

  if (!PTy)
    return false;

  if (!context)
    return false;

  const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
  const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";

  for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
    if (TypeName == specialTypes[i])
      return true;

  return false;
}

SDValue NVPTXTargetLowering::LowerFormalArguments(
    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
    const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
    SmallVectorImpl<SDValue> &InVals) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const DataLayout *TD = getDataLayout();

  const Function *F = MF.getFunction();
  const AttributeSet &PAL = F->getAttributes();
  const TargetLowering *TLI = DAG.getTarget().getTargetLowering();

  SDValue Root = DAG.getRoot();
  std::vector<SDValue> OutChains;

  bool isKernel = llvm::isKernelFunction(*F);
  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;

  std::vector<Type *> argTypes;
  std::vector<const Argument *> theArgs;
  for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
       I != E; ++I) {
    theArgs.push_back(I);
    argTypes.push_back(I->getType());
  }
  // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
  // Ins.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     shows up separately in Ins), or
  //   * if there is a vector argument with more elements than the typical
  //     vector length (generally more than 4), where each element is
  //     individually present in Ins.
  // So a different index should be used for indexing into Ins.
  // See the similar issue in LowerCall.
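  // For example, an argument of type { float, i32 } contributes one entry to
  // argTypes but two entries to Ins, so InsIdx can run ahead of i below.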
  unsigned InsIdx = 0;

  int idx = 0;
  for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
    Type *Ty = argTypes[i];

    // If the kernel argument is image*_t or sampler_t, convert it to an
    // i32 constant holding the parameter position. This can later be
    // matched in the AsmPrinter to output the correct mangled name.
    if (isImageOrSamplerVal(
            theArgs[i],
            (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
                                     : nullptr))) {
      assert(isKernel && "Only kernels can have image/sampler params");
      InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
      continue;
    }

    if (theArgs[i]->use_empty()) {
      // Argument is dead.
      if (Ty->isAggregateType()) {
        SmallVector<EVT, 16> vtparts;

        ComputePTXValueVTs(*this, Ty, vtparts);
        assert(vtparts.size() > 0 && "empty aggregate type not expected");
        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
             ++parti) {
          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
          ++InsIdx;
        }
        if (vtparts.size() > 0)
          --InsIdx;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
        for (unsigned parti = 0; parti < NumRegs; ++parti) {
          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
          ++InsIdx;
        }
        if (NumRegs > 0)
          --InsIdx;
        continue;
      }
      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
      continue;
    }

    // In the following cases, assign a node order of "idx+1"
    // to newly created nodes. The SDNodes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "idx+1" holds that order.
    if (!PAL.hasAttribute(i + 1, Attribute::ByVal)) {
      if (Ty->isAggregateType()) {
        SmallVector<EVT, 16> vtparts;
        SmallVector<uint64_t, 16> offsets;

        // NOTE: Here, we lose the ability to issue vector loads for vectors
        // that are a part of a struct.  This should be investigated in the
        // future.
        ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
        assert(vtparts.size() > 0 && "empty aggregate type not expected");
        bool aggregateIsPacked = false;
        if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
          aggregateIsPacked = STy->isPacked();

        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
             ++parti) {
          EVT partVT = vtparts[parti];
          Value *srcValue = Constant::getNullValue(
              PointerType::get(partVT.getTypeForEVT(F->getContext()),
                               llvm::ADDRESS_SPACE_PARAM));
          SDValue srcAddr =
              DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                          DAG.getConstant(offsets[parti], getPointerTy()));
          unsigned partAlign =
              aggregateIsPacked ? 1
                                : TD->getABITypeAlignment(
                                      partVT.getTypeForEVT(F->getContext()));
          SDValue p;
          if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) {
            ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
                                     ISD::SEXTLOAD : ISD::ZEXTLOAD;
            p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
                               MachinePointerInfo(srcValue), partVT, false,
                               false, partAlign);
          } else {
            p = DAG.getLoad(partVT, dl, Root, srcAddr,
                            MachinePointerInfo(srcValue), false, false, false,
                            partAlign);
          }
          if (p.getNode())
            p.getNode()->setIROrder(idx + 1);
          InVals.push_back(p);
          ++InsIdx;
        }
        if (vtparts.size() > 0)
          --InsIdx;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
        unsigned NumElts = ObjectVT.getVectorNumElements();
        assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
               "Vector was not scalarized");
        unsigned Ofst = 0;
        EVT EltVT = ObjectVT.getVectorElementType();

        // V1 load
        // f32 = load ...
        if (NumElts == 1) {
          // We only have one element, so just directly load it
          Value *SrcValue = Constant::getNullValue(PointerType::get(
              EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                                        DAG.getConstant(Ofst, getPointerTy()));
          SDValue P = DAG.getLoad(
              EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
              false, true,
              TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
          if (P.getNode())
            P.getNode()->setIROrder(idx + 1);

          if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
            P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
          InVals.push_back(P);
          Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
          ++InsIdx;
        } else if (NumElts == 2) {
          // V2 load
          // f32,f32 = load ...
          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
          Value *SrcValue = Constant::getNullValue(PointerType::get(
              VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                                        DAG.getConstant(Ofst, getPointerTy()));
          SDValue P = DAG.getLoad(
              VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
              false, true,
              TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
          if (P.getNode())
            P.getNode()->setIROrder(idx + 1);

          SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                     DAG.getIntPtrConstant(0));
          SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                     DAG.getIntPtrConstant(1));

          if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
            Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0);
            Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1);
          }

          InVals.push_back(Elt0);
          InVals.push_back(Elt1);
          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
          InsIdx += 2;
        } else {
          // V4 loads
          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
          // the vector will be expanded to a power of 2 elements, so we know
          // we can always round up to the next multiple of 4 when creating
          // the vector loads.
          // e.g.  4 elem => 1 ld.v4
          //       6 elem => 2 ld.v4
          //       8 elem => 2 ld.v4
          //      11 elem => 3 ld.v4
          unsigned VecSize = 4;
          if (EltVT.getSizeInBits() == 64) {
            VecSize = 2;
          }
          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
          for (unsigned i = 0; i < NumElts; i += VecSize) {
            Value *SrcValue = Constant::getNullValue(
                PointerType::get(VecVT.getTypeForEVT(F->getContext()),
                                 llvm::ADDRESS_SPACE_PARAM));
            SDValue SrcAddr =
                DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                            DAG.getConstant(Ofst, getPointerTy()));
            SDValue P = DAG.getLoad(
                VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
                false, true,
                TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
            if (P.getNode())
              P.getNode()->setIROrder(idx + 1);

            for (unsigned j = 0; j < VecSize; ++j) {
              if (i + j >= NumElts)
                break;
              SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                        DAG.getIntPtrConstant(j));
              if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
                Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt);
              InVals.push_back(Elt);
            }
            Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
          }
          InsIdx += NumElts;
        }

        if (NumElts > 0)
          --InsIdx;
        continue;
      }
      // A plain scalar.
      EVT ObjectVT = getValueType(Ty);
      // If ABI, load from the param symbol
      SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
      Value *srcValue = Constant::getNullValue(PointerType::get(
          ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
      SDValue p;
      if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) {
        ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
                                 ISD::SEXTLOAD : ISD::ZEXTLOAD;
        p = DAG.getExtLoad(
            ExtOp, dl, Ins[InsIdx].VT, Root, Arg,
            MachinePointerInfo(srcValue), ObjectVT, false, false,
            TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
      } else {
        p = DAG.getLoad(
            Ins[InsIdx].VT, dl, Root, Arg,
            MachinePointerInfo(srcValue), false, false, false,
            TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
      }
      if (p.getNode())
        p.getNode()->setIROrder(idx + 1);
      InVals.push_back(p);
      continue;
    }

    // Param has the ByVal attribute.
    // Return MoveParam(param symbol).
    // Ideally, the param symbol could be returned directly, but when the
    // SDNode builder decides to use it in a CopyToReg(), the machine
    // instruction fails because a TargetExternalSymbol (not lowered) is
    // target dependent, and CopyToReg assumes the source is lowered.
    EVT ObjectVT = getValueType(Ty);
    assert(ObjectVT == Ins[InsIdx].VT &&
           "Ins type did not match function type");
    SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
    SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
    if (p.getNode())
      p.getNode()->setIROrder(idx + 1);
    if (isKernel)
      InVals.push_back(p);
    else {
      SDValue p2 = DAG.getNode(
          ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
          DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
      InVals.push_back(p2);
    }
  }

  // Clang will check explicit VarArg and issue an error if any is present.
  // However, Clang will let code with an implicit var arg like f() pass.
  // See bug 617733. We treat this case as if the arg list is empty.
  // if (F.isVarArg()) {
  //   assert(0 && "VarArg not supported yet!");
  // }

  if (!OutChains.empty())
    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));

  return Chain;
}


SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 SDLoc dl, SelectionDAG &DAG) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const Function *F = MF.getFunction();
  Type *RetTy = F->getReturnType();
  const DataLayout *TD = getDataLayout();

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;

  if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
    // If we have a vector type, the OutVals array will be the scalarized
    // components, and we have to combine them into one or more vector stores.
    unsigned NumElts = VTy->getNumElements();
    assert(NumElts == Outs.size() && "Bad scalarization of return value");

    EVT EltVT = getValueType(RetTy).getVectorElementType();
    bool NeedExtend = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExtend = true;

    // V1 store
    if (NumElts == 1) {
      SDValue StoreVal = OutVals[0];
      // We only have one element, so just directly store it
      if (NeedExtend)
        StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
      SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                      DAG.getVTList(MVT::Other), Ops,
                                      EltVT, MachinePointerInfo());

    } else if (NumElts == 2) {
      // V2 store
      SDValue StoreVal0 = OutVals[0];
      SDValue StoreVal1 = OutVals[1];

      if (NeedExtend) {
        StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
        StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
      }

      SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
                        StoreVal1 };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
                                      DAG.getVTList(MVT::Other), Ops,
                                      EltVT, MachinePointerInfo());
    } else {
      // V4 stores
      // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
      // vector will be expanded to a power of 2 elements, so we know we can
      // always round up to the next multiple of 4 when creating the vector
      // stores.
      // e.g.  4 elem => 1 st.v4
      //       6 elem => 2 st.v4
      //       8 elem => 2 st.v4
      //      11 elem => 3 st.v4

      unsigned VecSize = 4;
      if (OutVals[0].getValueType().getSizeInBits() == 64)
        VecSize = 2;

      unsigned Offset = 0;

      EVT VecVT =
          EVT::getVectorVT(F->getContext(), EltVT, VecSize);
      unsigned PerStoreOffset =
          TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));

      for (unsigned i = 0; i < NumElts; i += VecSize) {
        // Get values
        SDValue StoreVal;
        SmallVector<SDValue, 8> Ops;
        Ops.push_back(Chain);
        Ops.push_back(DAG.getConstant(Offset, MVT::i32));
        unsigned Opc = NVPTXISD::StoreRetvalV2;
        EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();

        StoreVal = OutVals[i];
        if (NeedExtend)
          StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
        Ops.push_back(StoreVal);

        if (i + 1 < NumElts) {
          StoreVal = OutVals[i + 1];
          if (NeedExtend)
            StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
        } else {
          StoreVal = DAG.getUNDEF(ExtendedVT);
        }
        Ops.push_back(StoreVal);

        if (VecSize == 4) {
          Opc = NVPTXISD::StoreRetvalV4;
          if (i + 2 < NumElts) {
            StoreVal = OutVals[i + 2];
            if (NeedExtend)
              StoreVal =
                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
          } else {
            StoreVal = DAG.getUNDEF(ExtendedVT);
          }
          Ops.push_back(StoreVal);

          if (i + 3 < NumElts) {
            StoreVal = OutVals[i + 3];
            if (NeedExtend)
              StoreVal =
                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
          } else {
            StoreVal = DAG.getUNDEF(ExtendedVT);
          }
          Ops.push_back(StoreVal);
        }

        Chain =
            DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), Ops,
                                    EltVT, MachinePointerInfo());
        Offset += PerStoreOffset;
      }
    }
  } else {
    SmallVector<EVT, 16> ValVTs;
    SmallVector<uint64_t, 16> Offsets;
    ComputePTXValueVTs(*this, RetTy, ValVTs, &Offsets, 0);
    assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");

    for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
      SDValue theVal = OutVals[i];
      EVT TheValType = theVal.getValueType();
      unsigned numElems = 1;
      if (TheValType.isVector())
        numElems = TheValType.getVectorNumElements();
      for (unsigned j = 0, je = numElems; j != je; ++j) {
        SDValue TmpVal = theVal;
        if (TheValType.isVector())
          TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                               TheValType.getVectorElementType(), TmpVal,
                               DAG.getIntPtrConstant(j));
        EVT TheStoreType = ValVTs[i];
        if (RetTy->isIntegerTy() &&
            TD->getTypeAllocSizeInBits(RetTy) < 32) {
          // The following zero-extension is for integer types only, and
          // specifically not for aggregates.
          TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
          TheStoreType = MVT::i32;
        } else if (TmpVal.getValueType().getSizeInBits() < 16)
          TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);

        SDValue Ops[] = {
          Chain,
          DAG.getConstant(Offsets[i], MVT::i32),
          TmpVal };
        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                        DAG.getVTList(MVT::Other), Ops,
                                        TheStoreType,
                                        MachinePointerInfo());
      }
    }
  }

  return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
}


void NVPTXTargetLowering::LowerAsmOperandForConstraint(
    SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
    SelectionDAG &DAG) const {
  if (Constraint.length() > 1)
    return;
  TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
}

// NVPTX supports vectors of legal types of any length in intrinsics, because
// the NVPTX-specific type legalizer will legalize them to a PTX-supported
// length.
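// For example, a v8f32 operand is accepted here because its f32 element type
// is legal, even though v8f32 itself is not.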
bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
  if (isTypeLegal(VT))
    return true;
  if (VT.isVector()) {
    MVT eVT = VT.getVectorElementType();
    if (isTypeLegal(eVT))
      return true;
  }
  return false;
}

static unsigned getOpcForTextureInstr(unsigned Intrinsic) {
  switch (Intrinsic) {
  default:
    return 0;

  case Intrinsic::nvvm_tex_1d_v4f32_i32:
    return NVPTXISD::Tex1DFloatI32;
  case Intrinsic::nvvm_tex_1d_v4f32_f32:
    return NVPTXISD::Tex1DFloatFloat;
  case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
    return NVPTXISD::Tex1DFloatFloatLevel;
  case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
    return NVPTXISD::Tex1DFloatFloatGrad;
  case Intrinsic::nvvm_tex_1d_v4i32_i32:
    return NVPTXISD::Tex1DI32I32;
  case Intrinsic::nvvm_tex_1d_v4i32_f32:
    return NVPTXISD::Tex1DI32Float;
  case Intrinsic::nvvm_tex_1d_level_v4i32_f32:
    return NVPTXISD::Tex1DI32FloatLevel;
  case Intrinsic::nvvm_tex_1d_grad_v4i32_f32:
    return NVPTXISD::Tex1DI32FloatGrad;

  case Intrinsic::nvvm_tex_1d_array_v4f32_i32:
    return NVPTXISD::Tex1DArrayFloatI32;
  case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
    return NVPTXISD::Tex1DArrayFloatFloat;
  case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
    return NVPTXISD::Tex1DArrayFloatFloatLevel;
  case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
    return NVPTXISD::Tex1DArrayFloatFloatGrad;
  case Intrinsic::nvvm_tex_1d_array_v4i32_i32:
    return NVPTXISD::Tex1DArrayI32I32;
  case Intrinsic::nvvm_tex_1d_array_v4i32_f32:
    return NVPTXISD::Tex1DArrayI32Float;
  case Intrinsic::nvvm_tex_1d_array_level_v4i32_f32:
    return NVPTXISD::Tex1DArrayI32FloatLevel;
  case Intrinsic::nvvm_tex_1d_array_grad_v4i32_f32:
    return NVPTXISD::Tex1DArrayI32FloatGrad;

  case Intrinsic::nvvm_tex_2d_v4f32_i32:
    return NVPTXISD::Tex2DFloatI32;
  case Intrinsic::nvvm_tex_2d_v4f32_f32:
    return NVPTXISD::Tex2DFloatFloat;
  case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
    return NVPTXISD::Tex2DFloatFloatLevel;
  case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
    return NVPTXISD::Tex2DFloatFloatGrad;
  case Intrinsic::nvvm_tex_2d_v4i32_i32:
    return NVPTXISD::Tex2DI32I32;
  case Intrinsic::nvvm_tex_2d_v4i32_f32:
    return NVPTXISD::Tex2DI32Float;
  case Intrinsic::nvvm_tex_2d_level_v4i32_f32:
    return NVPTXISD::Tex2DI32FloatLevel;
  case Intrinsic::nvvm_tex_2d_grad_v4i32_f32:
    return NVPTXISD::Tex2DI32FloatGrad;

  case Intrinsic::nvvm_tex_2d_array_v4f32_i32:
    return NVPTXISD::Tex2DArrayFloatI32;
  case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
    return NVPTXISD::Tex2DArrayFloatFloat;
  case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
    return NVPTXISD::Tex2DArrayFloatFloatLevel;
  case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
    return NVPTXISD::Tex2DArrayFloatFloatGrad;
  case Intrinsic::nvvm_tex_2d_array_v4i32_i32:
    return NVPTXISD::Tex2DArrayI32I32;
  case Intrinsic::nvvm_tex_2d_array_v4i32_f32:
    return NVPTXISD::Tex2DArrayI32Float;
  case Intrinsic::nvvm_tex_2d_array_level_v4i32_f32:
    return NVPTXISD::Tex2DArrayI32FloatLevel;
  case Intrinsic::nvvm_tex_2d_array_grad_v4i32_f32:
    return NVPTXISD::Tex2DArrayI32FloatGrad;

  case Intrinsic::nvvm_tex_3d_v4f32_i32:
    return NVPTXISD::Tex3DFloatI32;
  case Intrinsic::nvvm_tex_3d_v4f32_f32:
    return NVPTXISD::Tex3DFloatFloat;
  case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
    return NVPTXISD::Tex3DFloatFloatLevel;
  case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
    return NVPTXISD::Tex3DFloatFloatGrad;
  case Intrinsic::nvvm_tex_3d_v4i32_i32:
    return NVPTXISD::Tex3DI32I32;
  case Intrinsic::nvvm_tex_3d_v4i32_f32:
    return NVPTXISD::Tex3DI32Float;
  case Intrinsic::nvvm_tex_3d_level_v4i32_f32:
    return NVPTXISD::Tex3DI32FloatLevel;
  case Intrinsic::nvvm_tex_3d_grad_v4i32_f32:
    return NVPTXISD::Tex3DI32FloatGrad;
  }
}

static unsigned getOpcForSurfaceInstr(unsigned Intrinsic) {
  switch (Intrinsic) {
  default:
    return 0;
  case Intrinsic::nvvm_suld_1d_i8_trap:
    return NVPTXISD::Suld1DI8Trap;
  case Intrinsic::nvvm_suld_1d_i16_trap:
    return NVPTXISD::Suld1DI16Trap;
  case Intrinsic::nvvm_suld_1d_i32_trap:
    return NVPTXISD::Suld1DI32Trap;
  case Intrinsic::nvvm_suld_1d_v2i8_trap:
    return NVPTXISD::Suld1DV2I8Trap;
  case Intrinsic::nvvm_suld_1d_v2i16_trap:
    return NVPTXISD::Suld1DV2I16Trap;
  case Intrinsic::nvvm_suld_1d_v2i32_trap:
    return NVPTXISD::Suld1DV2I32Trap;
  case Intrinsic::nvvm_suld_1d_v4i8_trap:
    return NVPTXISD::Suld1DV4I8Trap;
  case Intrinsic::nvvm_suld_1d_v4i16_trap:
    return NVPTXISD::Suld1DV4I16Trap;
  case Intrinsic::nvvm_suld_1d_v4i32_trap:
    return NVPTXISD::Suld1DV4I32Trap;
  case Intrinsic::nvvm_suld_1d_array_i8_trap:
    return NVPTXISD::Suld1DArrayI8Trap;
  case Intrinsic::nvvm_suld_1d_array_i16_trap:
    return NVPTXISD::Suld1DArrayI16Trap;
  case Intrinsic::nvvm_suld_1d_array_i32_trap:
    return NVPTXISD::Suld1DArrayI32Trap;
  case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
    return NVPTXISD::Suld1DArrayV2I8Trap;
  case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
    return NVPTXISD::Suld1DArrayV2I16Trap;
  case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
    return NVPTXISD::Suld1DArrayV2I32Trap;
  case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
    return NVPTXISD::Suld1DArrayV4I8Trap;
  case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
    return NVPTXISD::Suld1DArrayV4I16Trap;
  case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
    return NVPTXISD::Suld1DArrayV4I32Trap;
  case Intrinsic::nvvm_suld_2d_i8_trap:
    return NVPTXISD::Suld2DI8Trap;
  case Intrinsic::nvvm_suld_2d_i16_trap:
    return NVPTXISD::Suld2DI16Trap;
  case Intrinsic::nvvm_suld_2d_i32_trap:
    return NVPTXISD::Suld2DI32Trap;
  case Intrinsic::nvvm_suld_2d_v2i8_trap:
    return NVPTXISD::Suld2DV2I8Trap;
  case Intrinsic::nvvm_suld_2d_v2i16_trap:
    return NVPTXISD::Suld2DV2I16Trap;
  case Intrinsic::nvvm_suld_2d_v2i32_trap:
    return NVPTXISD::Suld2DV2I32Trap;
  case Intrinsic::nvvm_suld_2d_v4i8_trap:
    return NVPTXISD::Suld2DV4I8Trap;
  case Intrinsic::nvvm_suld_2d_v4i16_trap:
    return NVPTXISD::Suld2DV4I16Trap;
  case Intrinsic::nvvm_suld_2d_v4i32_trap:
    return NVPTXISD::Suld2DV4I32Trap;
  case Intrinsic::nvvm_suld_2d_array_i8_trap:
    return NVPTXISD::Suld2DArrayI8Trap;
  case Intrinsic::nvvm_suld_2d_array_i16_trap:
    return NVPTXISD::Suld2DArrayI16Trap;
  case Intrinsic::nvvm_suld_2d_array_i32_trap:
    return NVPTXISD::Suld2DArrayI32Trap;
  case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
    return NVPTXISD::Suld2DArrayV2I8Trap;
  case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
    return NVPTXISD::Suld2DArrayV2I16Trap;
  case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
    return NVPTXISD::Suld2DArrayV2I32Trap;
  case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
    return NVPTXISD::Suld2DArrayV4I8Trap;
  case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
    return NVPTXISD::Suld2DArrayV4I16Trap;
  case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
    return NVPTXISD::Suld2DArrayV4I32Trap;
  case Intrinsic::nvvm_suld_3d_i8_trap:
    return NVPTXISD::Suld3DI8Trap;
  case Intrinsic::nvvm_suld_3d_i16_trap:
    return NVPTXISD::Suld3DI16Trap;
  case Intrinsic::nvvm_suld_3d_i32_trap:
    return NVPTXISD::Suld3DI32Trap;
  case Intrinsic::nvvm_suld_3d_v2i8_trap:
    return NVPTXISD::Suld3DV2I8Trap;
  case Intrinsic::nvvm_suld_3d_v2i16_trap:
    return NVPTXISD::Suld3DV2I16Trap;
  case Intrinsic::nvvm_suld_3d_v2i32_trap:
    return NVPTXISD::Suld3DV2I32Trap;
  case Intrinsic::nvvm_suld_3d_v4i8_trap:
    return NVPTXISD::Suld3DV4I8Trap;
  case Intrinsic::nvvm_suld_3d_v4i16_trap:
    return NVPTXISD::Suld3DV4I16Trap;
  case Intrinsic::nvvm_suld_3d_v4i32_trap:
    return NVPTXISD::Suld3DV4I32Trap;
  }
}

// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need the information that is only available in
// the "Value" type of the destination pointer. In particular, the address
// space information.
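// For the nvvm intrinsics handled below, Info.memVT records the type actually
// read or written, and for ldu/ldg the alignment is recovered from the
// intrinsic call's "align" metadata.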
bool NVPTXTargetLowering::getTgtMemIntrinsic(
    IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
  switch (Intrinsic) {
  default:
    return false;

  case Intrinsic::nvvm_atomic_load_add_f32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_atomic_load_inc_32:
  case Intrinsic::nvvm_atomic_load_dec_32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
      Info.memVT = getValueType(I.getType());
    else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
      Info.memVT = getPointerTy();
    else
      Info.memVT = getValueType(I.getType());
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;

    // The alignment is available as metadata.
    // Grab it and set the alignment.
    assert(I.hasMetadataOtherThanDebugLoc() && "Must have alignment metadata");
    MDNode *AlignMD = I.getMetadata("align");
    assert(AlignMD && "Must have a non-null MDNode");
    assert(AlignMD->getNumOperands() == 1 && "Must have a single operand");
    Value *Align = AlignMD->getOperand(0);
    int64_t Alignment = cast<ConstantInt>(Align)->getZExtValue();
    Info.align = Alignment;

    return true;
  }
  case Intrinsic::nvvm_ldg_global_i:
  case Intrinsic::nvvm_ldg_global_f:
  case Intrinsic::nvvm_ldg_global_p: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    if (Intrinsic == Intrinsic::nvvm_ldg_global_i)
      Info.memVT = getValueType(I.getType());
    else if (Intrinsic == Intrinsic::nvvm_ldg_global_p)
      Info.memVT = getPointerTy();
    else
      Info.memVT = getValueType(I.getType());
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;

    // The alignment is available as metadata.
    // Grab it and set the alignment.
    assert(I.hasMetadataOtherThanDebugLoc() && "Must have alignment metadata");
    MDNode *AlignMD = I.getMetadata("align");
    assert(AlignMD && "Must have a non-null MDNode");
    assert(AlignMD->getNumOperands() == 1 && "Must have a single operand");
    Value *Align = AlignMD->getOperand(0);
    int64_t Alignment = cast<ConstantInt>(Align)->getZExtValue();
    Info.align = Alignment;

    return true;
  }

  case Intrinsic::nvvm_tex_1d_v4f32_i32:
  case Intrinsic::nvvm_tex_1d_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_array_v4f32_i32:
  case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_v4f32_i32:
  case Intrinsic::nvvm_tex_2d_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_array_v4f32_i32:
  case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_3d_v4f32_i32:
  case Intrinsic::nvvm_tex_3d_v4f32_f32:
  case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_3d_grad_v4f32_f32: {
    Info.opc = getOpcForTextureInstr(Intrinsic);
    Info.memVT = MVT::f32;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }
  case Intrinsic::nvvm_tex_1d_v4i32_i32:
  case Intrinsic::nvvm_tex_1d_v4i32_f32:
  case Intrinsic::nvvm_tex_1d_level_v4i32_f32:
  case Intrinsic::nvvm_tex_1d_grad_v4i32_f32:
  case Intrinsic::nvvm_tex_1d_array_v4i32_i32:
  case Intrinsic::nvvm_tex_1d_array_v4i32_f32:
  case Intrinsic::nvvm_tex_1d_array_level_v4i32_f32:
  case Intrinsic::nvvm_tex_1d_array_grad_v4i32_f32:
  case Intrinsic::nvvm_tex_2d_v4i32_i32:
  case Intrinsic::nvvm_tex_2d_v4i32_f32:
  case Intrinsic::nvvm_tex_2d_level_v4i32_f32:
  case Intrinsic::nvvm_tex_2d_grad_v4i32_f32:
  case Intrinsic::nvvm_tex_2d_array_v4i32_i32:
  case Intrinsic::nvvm_tex_2d_array_v4i32_f32:
  case Intrinsic::nvvm_tex_2d_array_level_v4i32_f32:
  case Intrinsic::nvvm_tex_2d_array_grad_v4i32_f32:
  case Intrinsic::nvvm_tex_3d_v4i32_i32:
  case Intrinsic::nvvm_tex_3d_v4i32_f32:
  case Intrinsic::nvvm_tex_3d_level_v4i32_f32:
  case Intrinsic::nvvm_tex_3d_grad_v4i32_f32: {
    Info.opc = getOpcForTextureInstr(Intrinsic);
    Info.memVT = MVT::i32;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }
  case Intrinsic::nvvm_suld_1d_i8_trap:
  case Intrinsic::nvvm_suld_1d_v2i8_trap:
  case Intrinsic::nvvm_suld_1d_v4i8_trap:
  case Intrinsic::nvvm_suld_1d_array_i8_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
  case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
  case Intrinsic::nvvm_suld_2d_i8_trap:
  case Intrinsic::nvvm_suld_2d_v2i8_trap:
  case Intrinsic::nvvm_suld_2d_v4i8_trap:
  case Intrinsic::nvvm_suld_2d_array_i8_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
  case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
  case Intrinsic::nvvm_suld_3d_i8_trap:
  case Intrinsic::nvvm_suld_3d_v2i8_trap:
  case Intrinsic::nvvm_suld_3d_v4i8_trap: {
    Info.opc = getOpcForSurfaceInstr(Intrinsic);
    Info.memVT = MVT::i8;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }
  case Intrinsic::nvvm_suld_1d_i16_trap:
  case Intrinsic::nvvm_suld_1d_v2i16_trap:
  case Intrinsic::nvvm_suld_1d_v4i16_trap:
  case Intrinsic::nvvm_suld_1d_array_i16_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
  case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
  case Intrinsic::nvvm_suld_2d_i16_trap:
  case Intrinsic::nvvm_suld_2d_v2i16_trap:
  case Intrinsic::nvvm_suld_2d_v4i16_trap:
  case Intrinsic::nvvm_suld_2d_array_i16_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
  case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
  case Intrinsic::nvvm_suld_3d_i16_trap:
  case Intrinsic::nvvm_suld_3d_v2i16_trap:
  case Intrinsic::nvvm_suld_3d_v4i16_trap: {
    Info.opc = getOpcForSurfaceInstr(Intrinsic);
    Info.memVT = MVT::i16;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }
  case Intrinsic::nvvm_suld_1d_i32_trap:
  case Intrinsic::nvvm_suld_1d_v2i32_trap:
  case Intrinsic::nvvm_suld_1d_v4i32_trap:
  case Intrinsic::nvvm_suld_1d_array_i32_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
  case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
  case Intrinsic::nvvm_suld_2d_i32_trap:
  case Intrinsic::nvvm_suld_2d_v2i32_trap:
  case Intrinsic::nvvm_suld_2d_v4i32_trap:
  case Intrinsic::nvvm_suld_2d_array_i32_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
  case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
  case Intrinsic::nvvm_suld_3d_i32_trap:
  case Intrinsic::nvvm_suld_3d_v2i32_trap:
  case Intrinsic::nvvm_suld_3d_v4i32_trap: {
    Info.opc = getOpcForSurfaceInstr(Intrinsic);
    Info.memVT = MVT::i32;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 16;
    return true;
  }

  }
  return false;
}

/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target specific optimizations, like loop strength reduction
/// (LoopStrengthReduce.cpp) and memory optimization for address mode
/// (CodeGenPrepare.cpp)
bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
                                                Type *Ty) const {
2559 
2560   // AddrMode - This represents an addressing mode of:
2561   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2562   //
2563   // The legal address modes are
2564   // - [avar]
2565   // - [areg]
2566   // - [areg+immoff]
2567   // - [immAddr]
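  //
  // In terms of the AddrMode fields checked below, these correspond
  // (roughly; the mapping is inferred from the checks in this function) to:
  // - [avar]        : BaseGV only
  // - [areg]        : a single register (Scale == 0 with a base register,
  //                   or Scale == 1 with only a scale register)
  // - [areg+immoff] : a single register plus BaseOffs
  // - [immAddr]     : BaseOffs only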

  if (AM.BaseGV) {
    if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
      return false;
    return true;
  }

  switch (AM.Scale) {
  case 0: // "r", "r+i" or "i" is allowed
    break;
  case 1:
    if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
      return false;
    // Otherwise we have r+i.
    break;
  default:
    // No scale other than 0 or 1 is allowed
    return false;
  }
  return true;
}

//===----------------------------------------------------------------------===//
//                         NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//

/// getConstraintType - Given a constraint letter, return the type of
/// constraint it is for this target.
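/// Note: the letters accepted here mirror the register-class mapping in
/// getRegForInlineAsmConstraint below, e.g. 'b' selects the i1 predicate
/// registers, 'r' the 32-bit integer registers, and 'd' the f64 registers.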
NVPTXTargetLowering::ConstraintType
NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    default:
      break;
    case 'b':
    case 'r':
    case 'h':
    case 'c':
    case 'l':
    case 'f':
    case 'd':
    case '0':
    case 'N':
      return C_RegisterClass;
    }
  }
  return TargetLowering::getConstraintType(Constraint);
}

std::pair<unsigned, const TargetRegisterClass *>
NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
                                                  MVT VT) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    case 'b':
      return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
    case 'c':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    case 'h':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    case 'r':
      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
    case 'l':
    case 'N':
      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
    case 'f':
      return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
    case 'd':
      return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
    }
  }
  return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
}

/// getFunctionAlignment - Return the Log2 alignment of this function.
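/// The returned value is a log2 amount, so returning 4 requests
/// 1 << 4 == 16 byte alignment for every function.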
unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
  return 4;
}

//===----------------------------------------------------------------------===//
//                         NVPTX DAG Combining
//===----------------------------------------------------------------------===//

extern unsigned FMAContractLevel;

/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
/// operands N0 and N1.  This is a helper for PerformADDCombine that is
/// called with the default operands, and if that fails, with commuted
/// operands.
static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                             const NVPTXSubtarget &Subtarget,
                                             CodeGenOpt::Level OptLevel) {
  SelectionDAG &DAG = DCI.DAG;
  // Skip vector types; only the scalar cases are handled below.
  EVT VT = N0.getValueType();
  if (VT.isVector())
    return SDValue();

  // fold (add (mul a, b), c) -> (mad a, b, c)
  //
  if (N0.getOpcode() == ISD::MUL) {
    assert(VT.isInteger());
    // For integer:
    // Since integer multiply-add costs the same as integer multiply
    // but is more costly than integer add, do the fusion only when
    // the mul is used only in the add.
    if (OptLevel == CodeGenOpt::None || VT != MVT::i32 ||
        !N0.getNode()->hasOneUse())
      return SDValue();

    // Do the folding
    return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
                       N0.getOperand(0), N0.getOperand(1), N1);
  } else if (N0.getOpcode() == ISD::FMUL) {
    if (VT == MVT::f32 || VT == MVT::f64) {
      if (FMAContractLevel == 0)
        return SDValue();

      // For floating point:
      // Do the fusion only when the mul has fewer than 5 uses and all
      // of them are adds.
      // The heuristic is that if a use is not an add, then that use
      // cannot be fused into an fma, so the mul is still needed anyway.
      // If there are more than 4 uses, even if they are all adds, fusing
      // them will increase register pressure.
      //
      int numUses = 0;
      int nonAddCount = 0;
      for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
                                UE = N0.getNode()->use_end();
           UI != UE; ++UI) {
        numUses++;
        SDNode *User = *UI;
        if (User->getOpcode() != ISD::FADD)
          ++nonAddCount;
      }
      if (numUses >= 5)
        return SDValue();
      if (nonAddCount) {
        int orderNo = N->getIROrder();
        int orderNo2 = N0.getNode()->getIROrder();
        // A simple heuristic for estimating potential register pressure:
        // the difference in IR order approximates the distance between the
        // def and this use, and a longer distance is more likely to cause
        // register pressure.
        if (orderNo - orderNo2 < 500)
          return SDValue();

        // Now, check if at least one of the FMUL's operands is live beyond
        // the node N, which guarantees that the FMA will not increase
        // register pressure at node N.
        bool opIsLive = false;
        const SDNode *left = N0.getOperand(0).getNode();
        const SDNode *right = N0.getOperand(1).getNode();

        if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
          opIsLive = true;

        if (!opIsLive)
          for (SDNode::use_iterator UI = left->use_begin(),
                                    UE = left->use_end();
               UI != UE; ++UI) {
            SDNode *User = *UI;
            int orderNo3 = User->getIROrder();
            if (orderNo3 > orderNo) {
              opIsLive = true;
              break;
            }
          }

        if (!opIsLive)
          for (SDNode::use_iterator UI = right->use_begin(),
                                    UE = right->use_end();
               UI != UE; ++UI) {
            SDNode *User = *UI;
            int orderNo3 = User->getIROrder();
            if (orderNo3 > orderNo) {
              opIsLive = true;
              break;
            }
          }

        if (!opIsLive)
          return SDValue();
      }

      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
                         N0.getOperand(0), N0.getOperand(1), N1);
    }
  }

  return SDValue();
}

/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const NVPTXSubtarget &Subtarget,
                                 CodeGenOpt::Level OptLevel) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  // First try with the default operand order.
  SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget,
                                                 OptLevel);
  if (Result.getNode())
    return Result;

  // If that didn't work, try again with the operands commuted.
  return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
}

static SDValue PerformANDCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
  // The type legalizer turns a vector load of i8 values into a zextload to i16
  // registers, optionally ANY_EXTENDs it (if the target type is integer),
  // and ANDs off the high 8 bits. Since we turn this load into a
  // target-specific DAG node, the DAG combiner fails to eliminate these AND
  // nodes. Do that here.
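  // For example (an illustrative sketch), a pattern like
  //   (and (any_extend (NVPTXISD::LoadV2 ...)), 0xff)
  // is rewritten as a zero_extend of the load result, since the zextload
  // already leaves the high bits of each i8 element cleared.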
  SDValue Val = N->getOperand(0);
  SDValue Mask = N->getOperand(1);

  if (isa<ConstantSDNode>(Val)) {
    std::swap(Val, Mask);
  }

  SDValue AExt;
  // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
  if (Val.getOpcode() == ISD::ANY_EXTEND) {
    AExt = Val;
    Val = Val->getOperand(0);
  }

  if (Val->isMachineOpcode() && Val->getMachineOpcode() == NVPTX::IMOV16rr) {
    Val = Val->getOperand(0);
  }

  if (Val->getOpcode() == NVPTXISD::LoadV2 ||
      Val->getOpcode() == NVPTXISD::LoadV4) {
    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
    if (!MaskCnst) {
      // Not an AND with a constant
      return SDValue();
    }

    uint64_t MaskVal = MaskCnst->getZExtValue();
    if (MaskVal != 0xff) {
      // Not an AND that chops off the top 8 bits
      return SDValue();
    }

    MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
    if (!Mem) {
      // Not a MemSDNode?!?
      return SDValue();
    }

    EVT MemVT = Mem->getMemoryVT();
    if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
      // We only handle the i8 case
      return SDValue();
    }

    unsigned ExtType =
        cast<ConstantSDNode>(Val->getOperand(Val->getNumOperands() - 1))
            ->getZExtValue();
    if (ExtType == ISD::SEXTLOAD) {
      // If for some reason the load is a sextload, the and is needed to zero
      // out the high 8 bits
      return SDValue();
    }

    bool AddTo = false;
    if (AExt.getNode() != nullptr) {
      // Re-insert the ext as a zext.
      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
                            AExt.getValueType(), Val);
      AddTo = true;
    }

    // If we get here, the AND is unnecessary.  Just replace it with the load.
    DCI.CombineTo(N, Val, AddTo);
  }

  return SDValue();
}

enum OperandSignedness {
  Signed = 0,
  Unsigned,
  Unknown
};

/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
/// that can be demoted to \p OptSize bits without loss of information. The
/// signedness of the operand, if determinable, is placed in \p S.
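/// For example, (sext i16 %x to i32) is demotable to 16 bits with \p S set
/// to Signed, and (zext i16 %x to i32) likewise with \p S set to Unsigned.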
static bool IsMulWideOperandDemotable(SDValue Op,
                                      unsigned OptSize,
                                      OperandSignedness &S) {
  S = Unknown;

  if (Op.getOpcode() == ISD::SIGN_EXTEND ||
      Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    EVT OrigVT = Op.getOperand(0).getValueType();
    if (OrigVT.getSizeInBits() == OptSize) {
      S = Signed;
      return true;
    }
  } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
    EVT OrigVT = Op.getOperand(0).getValueType();
    if (OrigVT.getSizeInBits() == OptSize) {
      S = Unsigned;
      return true;
    }
  }

  return false;
}

/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
/// be demoted to \p OptSize bits without loss of information. If the operands
/// contain a constant, it should appear as the RHS operand. The signedness of
/// the operands is placed in \p IsSigned.
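/// For example, LHS = (sext i16 %a to i32) with RHS = constant 1000 is
/// demotable to 16 bits with \p IsSigned == true, since 1000 fits in a
/// signed 16-bit value.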
static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
                                        unsigned OptSize,
                                        bool &IsSigned) {

  OperandSignedness LHSSign;

  // The LHS operand must be a demotable op
  if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
    return false;

  // We should have been able to determine the signedness from the LHS
  if (LHSSign == Unknown)
    return false;

  IsSigned = (LHSSign == Signed);

  // The RHS can be a demotable op or a constant
  if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
    APInt Val = CI->getAPIntValue();
    if (LHSSign == Unsigned)
      return Val.isIntN(OptSize);
    return Val.isSignedIntN(OptSize);
  } else {
    OperandSignedness RHSSign;
    if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
      return false;

    if (LHSSign != RHSSign)
      return false;

    return true;
  }
}

/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
/// amount.
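/// For example (an illustrative sketch):
///   (mul i32 (sext i16 %a), (sext i16 %b)) --> (MUL_WIDE_SIGNED %a, %b)
///   (shl i32 (zext i16 %a), 4)             --> (MUL_WIDE_UNSIGNED %a, 16)
/// where the SHL case first rewrites the constant shift as the equivalent
/// multiply by a power of two.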
static SDValue TryMULWIDECombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
  EVT MulType = N->getValueType(0);
  if (MulType != MVT::i32 && MulType != MVT::i64) {
    return SDValue();
  }

  unsigned OptSize = MulType.getSizeInBits() >> 1;
  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);

  // Canonicalize the multiply so the constant (if any) is on the right
  if (N->getOpcode() == ISD::MUL) {
    if (isa<ConstantSDNode>(LHS)) {
      std::swap(LHS, RHS);
    }
  }

  // If we have a SHL, determine the actual multiply amount
  if (N->getOpcode() == ISD::SHL) {
    ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
    if (!ShlRHS) {
      return SDValue();
    }

    APInt ShiftAmt = ShlRHS->getAPIntValue();
    unsigned BitWidth = MulType.getSizeInBits();
    if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
      APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
      RHS = DCI.DAG.getConstant(MulVal, MulType);
    } else {
      return SDValue();
    }
  }

  bool Signed;
  // Verify that our operands are demotable
  if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
    return SDValue();
  }

  EVT DemotedVT;
  if (MulType == MVT::i32) {
    DemotedVT = MVT::i16;
  } else {
    DemotedVT = MVT::i32;
  }

  // Truncate the operands to the correct size. Note that these are just for
  // type consistency and will (likely) be eliminated in later phases.
  SDValue TruncLHS =
      DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, LHS);
  SDValue TruncRHS =
      DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, RHS);

  unsigned Opc;
  if (Signed) {
    Opc = NVPTXISD::MUL_WIDE_SIGNED;
  } else {
    Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
  }

  return DCI.DAG.getNode(Opc, SDLoc(N), MulType, TruncLHS, TruncRHS);
}

/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
static SDValue PerformMULCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 CodeGenOpt::Level OptLevel) {
  if (OptLevel > 0) {
    // Try mul.wide combining at OptLevel > 0
    SDValue Ret = TryMULWIDECombine(N, DCI);
    if (Ret.getNode())
      return Ret;
  }

  return SDValue();
}

/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
static SDValue PerformSHLCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 CodeGenOpt::Level OptLevel) {
  if (OptLevel > 0) {
    // Try mul.wide combining at OptLevel > 0
    SDValue Ret = TryMULWIDECombine(N, DCI);
    if (Ret.getNode())
      return Ret;
  }

  return SDValue();
}

SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                               DAGCombinerInfo &DCI) const {
  // FIXME: Get this from the DAG somehow
  CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
  switch (N->getOpcode()) {
    default: break;
    case ISD::ADD:
    case ISD::FADD:
      return PerformADDCombine(N, DCI, nvptxSubtarget, OptLevel);
    case ISD::MUL:
      return PerformMULCombine(N, DCI, OptLevel);
    case ISD::SHL:
      return PerformSHLCombine(N, DCI, OptLevel);
    case ISD::AND:
      return PerformANDCombine(N, DCI);
  }
  return SDValue();
}

/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
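/// For example, (v2f32 (load %ptr)) becomes a single NVPTXISD::LoadV2 node
/// producing two f32 values plus a chain; the scalar results are then
/// recombined with BUILD_VECTOR to satisfy the original uses.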
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                              SmallVectorImpl<SDValue> &Results) {
  EVT ResVT = N->getValueType(0);
  SDLoc DL(N);

  assert(ResVT.isVector() && "Vector load must have vector type");

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal.  We can (and should) split that into 2 loads of <2 x double> here
  // but I'm leaving that as a TODO for now.
  assert(ResVT.isSimple() && "Can only handle simple types");
  switch (ResVT.getSimpleVT().SimpleTy) {
  default:
    return;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f32:
    // This is a "native" vector type
    break;
  }

  EVT EltVT = ResVT.getVectorElementType();
  unsigned NumElts = ResVT.getVectorNumElements();

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  bool NeedTrunc = false;
  if (EltVT.getSizeInBits() < 16) {
    EltVT = MVT::i16;
    NeedTrunc = true;
  }

  unsigned Opcode = 0;
  SDVTList LdResVTs;

  switch (NumElts) {
  default:
    return;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
    break;
  case 4: {
    Opcode = NVPTXISD::LoadV4;
    EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
    LdResVTs = DAG.getVTList(ListVTs);
    break;
  }
  }

  SmallVector<SDValue, 8> OtherOps;

  // Copy regular operands
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
    OtherOps.push_back(N->getOperand(i));

  LoadSDNode *LD = cast<LoadSDNode>(N);

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
                                          LD->getMemoryVT(),
                                          LD->getMemOperand());

  SmallVector<SDValue, 4> ScalarRes;

  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewLD.getValue(i);
    if (NeedTrunc)
      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
    ScalarRes.push_back(Res);
  }

  SDValue LoadChain = NewLD.getValue(NumElts);

  SDValue BuildVec = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);

  Results.push_back(BuildVec);
  Results.push_back(LoadChain);
}

static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(0);
  SDValue Intrin = N->getOperand(1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldg_global_i:
  case Intrinsic::nvvm_ldg_global_f:
  case Intrinsic::nvvm_ldg_global_p:
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.
      // Therefore, we must ensure the type is legal.  For i1 and i8, we set
      // the loaded type to i16 and propagate the "real" type as the memory
      // type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      switch (NumElts) {
      default:
        return;
      case 2:
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV2;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV2;
          break;
        }
        LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
        break;
      case 4: {
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV4;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV4;
          break;
        }
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(ListVTs);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands: the chain first, then everything except
      // operand 1 (the intrinsic ID).
      OtherOps.push_back(Chain);
      for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
        OtherOps.push_back(N->getOperand(i));

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
                                              MemSD->getMemoryVT(),
                                              MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(i);
        if (NeedTrunc)
          Res =
              DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
        ScalarRes.push_back(Res);
      }

      SDValue LoadChain = NewLD.getValue(NumElts);

      SDValue BuildVec =
          DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);

      Results.push_back(BuildVec);
      Results.push_back(LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops;
      for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
        Ops.push_back(N->getOperand(i));

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, Ops,
                                  MVT::i8, MemSD->getMemOperand());

      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
                                    NewLD.getValue(0)));
      Results.push_back(NewLD.getValue(1));
    }
  }
  }
}

void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error("Unhandled custom legalization");
  case ISD::LOAD:
    ReplaceLoadVector(N, DAG, Results);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  }
}

// Pin NVPTXSection's and NVPTXTargetObjectFile's vtables to this file.
void NVPTXSection::anchor() {}

NVPTXTargetObjectFile::~NVPTXTargetObjectFile() {
  delete TextSection;
  delete DataSection;
  delete BSSSection;
  delete ReadOnlySection;

  delete StaticCtorSection;
  delete StaticDtorSection;
  delete LSDASection;
  delete EHFrameSection;
  delete DwarfAbbrevSection;
  delete DwarfInfoSection;
  delete DwarfLineSection;
  delete DwarfFrameSection;
  delete DwarfPubTypesSection;
  delete DwarfDebugInlineSection;
  delete DwarfStrSection;
  delete DwarfLocSection;
  delete DwarfARangesSection;
  delete DwarfRangesSection;
  delete DwarfMacroInfoSection;
}
3319