// RUN: mlir-opt -allow-unregistered-dialect --convert-gpu-to-nvvm --split-input-file %s | FileCheck --check-prefix=NVVM %s
// RUN: mlir-opt -allow-unregistered-dialect --convert-gpu-to-rocdl --split-input-file %s | FileCheck --check-prefix=ROCDL %s

// Checks how memory attributions on gpu.func (private and workgroup buffers)
// are lowered by the two conversions exercised above. Every case is checked
// twice, once per check prefix; the expectations differ only in the pointer
// type produced for private buffers.

gpu.module @kernel {
  // NVVM-LABEL:  llvm.func @private
  // ROCDL-LABEL: llvm.func @private
  gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, 5>) {
    // Allocate private memory inside the function.
    // NVVM: %[[size:.*]] = llvm.mlir.constant(4 : i64) : !llvm.i64
    // NVVM: %[[raw:.*]] = llvm.alloca %[[size]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>

    // ROCDL: %[[size:.*]] = llvm.mlir.constant(4 : i64) : !llvm.i64
    // ROCDL: %[[raw:.*]] = llvm.alloca %[[size]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float, 5>

    // Populate the memref descriptor.
    // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
    // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 5>, ptr<float, 5>, i64, array<1 x i64>, array<1 x i64>)>
    // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // "Store" lowering should work just as any other memref, only check that
    // we emit some core instructions. The extractvalue check uses the
    // previously captured descriptor (no ":.*" redefinition) so that it
    // verifies the store reads from the descriptor built above.
    // NVVM: llvm.extractvalue %[[descr6]]
    // NVVM: llvm.getelementptr
    // NVVM: llvm.store

    // ROCDL: llvm.extractvalue %[[descr6]]
    // ROCDL: llvm.getelementptr
    // ROCDL: llvm.store
    %c0 = constant 0 : index
    store %arg0, %arg1[%c0] : memref<4xf32, 5>

    "terminator"() : () -> ()
  }
}

// -----

gpu.module @kernel {
  // Workgroup buffers are allocated as globals.
  // NVVM: llvm.mlir.global internal @[[$buffer:.*]]()
  // NVVM-SAME:  addr_space = 3
  // NVVM-SAME:  !llvm.array<4 x float>

  // ROCDL: llvm.mlir.global internal @[[$buffer:.*]]()
  // ROCDL-SAME:  addr_space = 3
  // ROCDL-SAME:  !llvm.array<4 x float>

  // NVVM-LABEL: llvm.func @workgroup
  // NVVM-SAME: {

  // ROCDL-LABEL: llvm.func @workgroup
  // ROCDL-SAME: {
  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, 3>) {
    // Get the address of the first element in the global array.
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // NVVM: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<4 x float>, 3>
    // NVVM: %[[raw:.*]] = llvm.getelementptr %[[addr]][%[[c0]], %[[c0]]]
    // NVVM-SAME: !llvm.ptr<float, 3>

    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // ROCDL: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<4 x float>, 3>
    // ROCDL: %[[raw:.*]] = llvm.getelementptr %[[addr]][%[[c0]], %[[c0]]]
    // ROCDL-SAME: !llvm.ptr<float, 3>

    // Populate the memref descriptor.
    // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<1 x i64>, array<1 x i64>)>
    // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<1 x i64>, array<1 x i64>)>
    // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // "Store" lowering should work just as any other memref, only check that
    // we emit some core instructions. The extractvalue check uses the
    // previously captured descriptor (no ":.*" redefinition) so that it
    // verifies the store reads from the descriptor built above.
    // NVVM: llvm.extractvalue %[[descr6]]
    // NVVM: llvm.getelementptr
    // NVVM: llvm.store

    // ROCDL: llvm.extractvalue %[[descr6]]
    // ROCDL: llvm.getelementptr
    // ROCDL: llvm.store
    %c0 = constant 0 : index
    store %arg0, %arg1[%c0] : memref<4xf32, 3>

    "terminator"() : () -> ()
  }
}

// -----

gpu.module @kernel {
  // Check that the total size was computed correctly.
  // NVVM: llvm.mlir.global internal @[[$buffer:.*]]()
  // NVVM-SAME:  addr_space = 3
  // NVVM-SAME:  !llvm.array<48 x float>

  // ROCDL: llvm.mlir.global internal @[[$buffer:.*]]()
  // ROCDL-SAME:  addr_space = 3
  // ROCDL-SAME:  !llvm.array<48 x float>

  // NVVM-LABEL: llvm.func @workgroup3d
  // ROCDL-LABEL: llvm.func @workgroup3d
  gpu.func @workgroup3d(%arg0: f32) workgroup(%arg1: memref<4x2x6xf32, 3>) {
    // Get the address of the first element in the global array.
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // NVVM: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<48 x float>, 3>
    // NVVM: %[[raw:.*]] = llvm.getelementptr %[[addr]][%[[c0]], %[[c0]]]
    // NVVM-SAME: !llvm.ptr<float, 3>

    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // ROCDL: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<48 x float>, 3>
    // ROCDL: %[[raw:.*]] = llvm.getelementptr %[[addr]][%[[c0]], %[[c0]]]
    // ROCDL-SAME: !llvm.ptr<float, 3>

    // Populate the memref descriptor. For each dimension the expected output
    // inserts the size (field 3) followed by the stride (field 4).
    // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<3 x i64>, array<3 x i64>)>
    // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // NVVM: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
    // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0]
    // NVVM: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
    // NVVM: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1]
    // NVVM: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // NVVM: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1]
    // NVVM: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // NVVM: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2]
    // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // NVVM: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]

    // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<3 x i64>, array<3 x i64>)>
    // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // ROCDL: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
    // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0]
    // ROCDL: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
    // ROCDL: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1]
    // ROCDL: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // ROCDL: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1]
    // ROCDL: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // ROCDL: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2]
    // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // ROCDL: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]

    %c0 = constant 0 : index
    store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3>
    "terminator"() : () -> ()
  }
}

// -----

gpu.module @kernel {
  // Check that several buffers are defined.
  // NVVM: llvm.mlir.global internal @[[$buffer1:.*]]()
  // NVVM-SAME:  !llvm.array<1 x float>
  // NVVM: llvm.mlir.global internal @[[$buffer2:.*]]()
  // NVVM-SAME:  !llvm.array<2 x float>

  // ROCDL: llvm.mlir.global internal @[[$buffer1:.*]]()
  // ROCDL-SAME:  !llvm.array<1 x float>
  // ROCDL: llvm.mlir.global internal @[[$buffer2:.*]]()
  // ROCDL-SAME:  !llvm.array<2 x float>

  // NVVM-LABEL: llvm.func @multiple
  // ROCDL-LABEL: llvm.func @multiple
  gpu.func @multiple(%arg0: f32)
      workgroup(%arg1: memref<1xf32, 3>, %arg2: memref<2xf32, 3>)
      private(%arg3: memref<3xf32, 5>, %arg4: memref<4xf32, 5>) {

    // Workgroup buffers.
    // NVVM: llvm.mlir.addressof @[[$buffer1]]
    // NVVM: llvm.mlir.addressof @[[$buffer2]]

    // ROCDL: llvm.mlir.addressof @[[$buffer1]]
    // ROCDL: llvm.mlir.addressof @[[$buffer2]]

    // Private buffers.
    // NVVM: %[[c3:.*]] = llvm.mlir.constant(3 : i64)
    // NVVM: llvm.alloca %[[c3]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : i64)
    // NVVM: llvm.alloca %[[c4]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>

    // ROCDL: %[[c3:.*]] = llvm.mlir.constant(3 : i64)
    // ROCDL: llvm.alloca %[[c3]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float, 5>
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : i64)
    // ROCDL: llvm.alloca %[[c4]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float, 5>

    %c0 = constant 0 : index
    store %arg0, %arg1[%c0] : memref<1xf32, 3>
    store %arg0, %arg2[%c0] : memref<2xf32, 3>
    store %arg0, %arg3[%c0] : memref<3xf32, 5>
    store %arg0, %arg4[%c0] : memref<4xf32, 5>
    "terminator"() : () -> ()
  }
}