• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
2// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
3
4gpu.module @test_module {
5  // CHECK-LABEL: func @gpu_index_ops()
6  // CHECK32-LABEL: func @gpu_index_ops()
7  func @gpu_index_ops()
8      -> (index, index, index, index, index, index,
9          index, index, index, index, index, index) {
10    // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
11
12    // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
13    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
14    %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
15    // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
16    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
17    %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
18    // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
19    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
20    %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)
21
22    // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
23    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
24    %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
25    // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
26    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
27    %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
28    // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
29    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
30    %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)
31
32    // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
33    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
34    %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
35    // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
36    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
37    %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
38    // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
39    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
40    %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)
41
42    // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
43    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
44    %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
45    // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
46    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
47    %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
48    // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
49    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
50    %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
51
52    std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
53               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
54        : index, index, index, index, index, index,
55          index, index, index, index, index, index
56  }
57}
58
59// -----
60
61gpu.module @test_module {
62  // CHECK-LABEL: func @gpu_index_comp
63  // CHECK32-LABEL: func @gpu_index_comp
64  func @gpu_index_comp(%idx : index) -> index {
65    // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
66    // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
67    %0 = addi %idx, %idx : index
68    // CHECK: llvm.return %{{.*}} : !llvm.i64
69    // CHECK32: llvm.return %{{.*}} : !llvm.i32
70    std.return %0 : index
71  }
72}
73
74// -----
75
76gpu.module @test_module {
77  // CHECK-LABEL: func @gpu_all_reduce_op()
78  gpu.func @gpu_all_reduce_op() {
79    %arg0 = constant 1.0 : f32
80    // TODO: Check full IR expansion once lowering has settled.
81    // CHECK: nvvm.shfl.sync.bfly
82    // CHECK: nvvm.barrier0
83    // CHECK: llvm.fadd
84    %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
85
86    gpu.return
87  }
88}
89
90// -----
91
92gpu.module @test_module {
93  // CHECK-LABEL: func @gpu_all_reduce_region()
94  gpu.func @gpu_all_reduce_region() {
95    %arg0 = constant 1 : i32
96    // TODO: Check full IR expansion once lowering has settled.
97    // CHECK: nvvm.shfl.sync.bfly
98    // CHECK: nvvm.barrier0
99    %result = "gpu.all_reduce"(%arg0) ({
100    ^bb(%lhs : i32, %rhs : i32):
101      %xor = xor %lhs, %rhs : i32
102      "gpu.yield"(%xor) : (i32) -> ()
103    }) : (i32) -> (i32)
104    gpu.return
105  }
106}
107
108// -----
109
110gpu.module @test_module {
111  // CHECK-LABEL: func @gpu_shuffle()
112  func @gpu_shuffle() -> (f32) {
113    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
114    %arg0 = constant 1.0 : f32
115    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : !llvm.i32
116    %arg1 = constant 4 : i32
117    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : !llvm.i32
118    %arg2 = constant 23 : i32
119    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : !llvm.i32
120    // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : !llvm.i32
121    // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : !llvm.i32
122    // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : !llvm.i32
123    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm.struct<(float, i1)>
124    // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(float, i1)>
125    // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(float, i1)>
126    %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
127
128    std.return %shfl : f32
129  }
130}
131
132// -----
133
134gpu.module @test_module {
135  // CHECK-LABEL: func @gpu_sync()
136  func @gpu_sync() {
137    // CHECK: nvvm.barrier0
138    gpu.barrier
139    std.return
140  }
141}
142
143// -----
144
145gpu.module @test_module {
146  // CHECK: llvm.func @__nv_fabsf(!llvm.float) -> !llvm.float
147  // CHECK: llvm.func @__nv_fabs(!llvm.double) -> !llvm.double
148  // CHECK-LABEL: func @gpu_fabs
149  func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
150    %result32 = std.absf %arg_f32 : f32
151    // CHECK: llvm.call @__nv_fabsf(%{{.*}}) : (!llvm.float) -> !llvm.float
152    %result64 = std.absf %arg_f64 : f64
153    // CHECK: llvm.call @__nv_fabs(%{{.*}}) : (!llvm.double) -> !llvm.double
154    std.return %result32, %result64 : f32, f64
155  }
156}
157
158// -----
159
160gpu.module @test_module {
161  // CHECK: llvm.func @__nv_ceilf(!llvm.float) -> !llvm.float
162  // CHECK: llvm.func @__nv_ceil(!llvm.double) -> !llvm.double
163  // CHECK-LABEL: func @gpu_ceil
164  func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
165    %result32 = std.ceilf %arg_f32 : f32
166    // CHECK: llvm.call @__nv_ceilf(%{{.*}}) : (!llvm.float) -> !llvm.float
167    %result64 = std.ceilf %arg_f64 : f64
168    // CHECK: llvm.call @__nv_ceil(%{{.*}}) : (!llvm.double) -> !llvm.double
169    std.return %result32, %result64 : f32, f64
170  }
171}
172
173// -----
174
175gpu.module @test_module {
176  // CHECK: llvm.func @__nv_floorf(!llvm.float) -> !llvm.float
177  // CHECK: llvm.func @__nv_floor(!llvm.double) -> !llvm.double
178  // CHECK-LABEL: func @gpu_floor
179  func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
180    %result32 = std.floorf %arg_f32 : f32
181    // CHECK: llvm.call @__nv_floorf(%{{.*}}) : (!llvm.float) -> !llvm.float
182    %result64 = std.floorf %arg_f64 : f64
183    // CHECK: llvm.call @__nv_floor(%{{.*}}) : (!llvm.double) -> !llvm.double
184    std.return %result32, %result64 : f32, f64
185  }
186}
187
188// -----
189
190gpu.module @test_module {
191  // CHECK: llvm.func @__nv_cosf(!llvm.float) -> !llvm.float
192  // CHECK: llvm.func @__nv_cos(!llvm.double) -> !llvm.double
193  // CHECK-LABEL: func @gpu_cos
194  func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
195    %result32 = std.cos %arg_f32 : f32
196    // CHECK: llvm.call @__nv_cosf(%{{.*}}) : (!llvm.float) -> !llvm.float
197    %result64 = std.cos %arg_f64 : f64
198    // CHECK: llvm.call @__nv_cos(%{{.*}}) : (!llvm.double) -> !llvm.double
199    std.return %result32, %result64 : f32, f64
200  }
201}
202
203// -----
204gpu.module @test_module {
205  // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
206  // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
207  // CHECK-LABEL: func @gpu_exp
208  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
209    %result32 = std.exp %arg_f32 : f32
210    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
211    %result64 = std.exp %arg_f64 : f64
212    // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
213    std.return %result32, %result64 : f32, f64
214  }
215}
216
217// -----
218
219gpu.module @test_module {
220  // CHECK: llvm.func @__nv_logf(!llvm.float) -> !llvm.float
221  // CHECK: llvm.func @__nv_log(!llvm.double) -> !llvm.double
222  // CHECK-LABEL: func @gpu_log
223  func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
224    %result32 = std.log %arg_f32 : f32
225    // CHECK: llvm.call @__nv_logf(%{{.*}}) : (!llvm.float) -> !llvm.float
226    %result64 = std.log %arg_f64 : f64
227    // CHECK: llvm.call @__nv_log(%{{.*}}) : (!llvm.double) -> !llvm.double
228    std.return %result32, %result64 : f32, f64
229  }
230}
231
232// -----
233
234gpu.module @test_module {
235  // CHECK: llvm.func @__nv_log10f(!llvm.float) -> !llvm.float
236  // CHECK: llvm.func @__nv_log10(!llvm.double) -> !llvm.double
237  // CHECK-LABEL: func @gpu_log10
238  func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
239    %result32 = std.log10 %arg_f32 : f32
240    // CHECK: llvm.call @__nv_log10f(%{{.*}}) : (!llvm.float) -> !llvm.float
241    %result64 = std.log10 %arg_f64 : f64
242    // CHECK: llvm.call @__nv_log10(%{{.*}}) : (!llvm.double) -> !llvm.double
243    std.return %result32, %result64 : f32, f64
244  }
245}
246
247// -----
248
249gpu.module @test_module {
250  // CHECK: llvm.func @__nv_log2f(!llvm.float) -> !llvm.float
251  // CHECK: llvm.func @__nv_log2(!llvm.double) -> !llvm.double
252  // CHECK-LABEL: func @gpu_log2
253  func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
254    %result32 = std.log2 %arg_f32 : f32
255    // CHECK: llvm.call @__nv_log2f(%{{.*}}) : (!llvm.float) -> !llvm.float
256    %result64 = std.log2 %arg_f64 : f64
257    // CHECK: llvm.call @__nv_log2(%{{.*}}) : (!llvm.double) -> !llvm.double
258    std.return %result32, %result64 : f32, f64
259  }
260}
261
262// -----
263
264gpu.module @test_module {
265  // CHECK: llvm.func @__nv_sinf(!llvm.float) -> !llvm.float
266  // CHECK: llvm.func @__nv_sin(!llvm.double) -> !llvm.double
267  // CHECK-LABEL: func @gpu_sin
268  func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
269    %result32 = std.sin %arg_f32 : f32
270    // CHECK: llvm.call @__nv_sinf(%{{.*}}) : (!llvm.float) -> !llvm.float
271    %result64 = std.sin %arg_f64 : f64
272    // CHECK: llvm.call @__nv_sin(%{{.*}}) : (!llvm.double) -> !llvm.double
273    std.return %result32, %result64 : f32, f64
274  }
275}
276
277// -----
278
279gpu.module @test_module {
280  // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
281  // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
282  // CHECK-LABEL: func @gpu_tanh
283  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
284    %result16 = std.tanh %arg_f16 : f16
285    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
286    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
287    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
288    %result32 = std.tanh %arg_f32 : f32
289    // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
290    %result64 = std.tanh %arg_f64 : f64
291    // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
292    std.return %result16, %result32, %result64 : f16, f32, f64
293  }
294}
295
296// -----
297
298gpu.module @test_module {
299  // CHECK: llvm.func @__nv_rsqrtf(!llvm.float) -> !llvm.float
300  // CHECK: llvm.func @__nv_rsqrt(!llvm.double) -> !llvm.double
301  // CHECK-LABEL: func @gpu_rsqrt
302  func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
303      -> (f16, f32, f64) {
304    %result16 = std.rsqrt %arg_f16 : f16
305    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
306    // CHECK-NEXT: llvm.call @__nv_rsqrtf(%{{.*}}) : (!llvm.float) -> !llvm.float
307    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
308    %result32 = std.rsqrt %arg_f32 : f32
309    // CHECK: llvm.call @__nv_rsqrtf(%{{.*}}) : (!llvm.float) -> !llvm.float
310    %result64 = std.rsqrt %arg_f64 : f64
311    // CHECK: llvm.call @__nv_rsqrt(%{{.*}}) : (!llvm.double) -> !llvm.double
312    std.return %result16, %result32, %result64 : f16, f32, f64
313  }
314}
315
316// -----
317
318gpu.module @test_module {
319  // CHECK: llvm.func @__nv_sqrtf(!llvm.float) -> !llvm.float
320  // CHECK: llvm.func @__nv_sqrt(!llvm.double) -> !llvm.double
321  // CHECK-LABEL: func @gpu_sqrt
322  func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
323      -> (f16, f32, f64) {
324    %result16 = std.sqrt %arg_f16 : f16
325    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
326    // CHECK-NEXT: llvm.call @__nv_sqrtf(%{{.*}}) : (!llvm.float) -> !llvm.float
327    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
328    %result32 = std.sqrt %arg_f32 : f32
329    // CHECK: llvm.call @__nv_sqrtf(%{{.*}}) : (!llvm.float) -> !llvm.float
330    %result64 = std.sqrt %arg_f64 : f64
331    // CHECK: llvm.call @__nv_sqrt(%{{.*}}) : (!llvm.double) -> !llvm.double
332    std.return %result16, %result32, %result64 : f16, f32, f64
333  }
334}
335
336// -----
337
338// Test that we handled properly operation with SymbolTable other than module op
339gpu.module @test_module {
340  "test.symbol_scope"() ({
341  // CHECK: test.symbol_scope
342  // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
343  // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
344  // CHECK-LABEL: func @gpu_exp
345    func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
346      %result32 = std.exp %arg_f32 : f32
347      // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
348      %result64 = std.exp %arg_f64 : f64
349      // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
350      std.return %result32, %result64 : f32, f64
351    }
352    "test.finish" () : () -> ()
353  }) : () -> ()
354}
355
356