• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert gpu.launch_func op into a sequence of
10 // GPU runtime calls. As most GPU runtimes do not have a stable published
11 // ABI, this pass uses a slim runtime layer that builds on top of the public
12 // API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17 
18 #include "../PassDetail.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
20 #include "mlir/Dialect/GPU/GPUDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/DerivedTypes.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/Support/Error.h"
33 #include "llvm/Support/FormatVariadic.h"
34 
35 using namespace mlir;
36 
/// Suffix appended to a GPU module's name to form the name of the LLVM global
/// string that holds the module's serialized binary.
static constexpr char kGpuBinaryStorageSuffix[] = "_gpubin_cst";
38 
39 namespace {
40 
41 class GpuToLLVMConversionPass
42     : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
43 public:
GpuToLLVMConversionPass(StringRef gpuBinaryAnnotation)44   GpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
45     if (!gpuBinaryAnnotation.empty())
46       this->gpuBinaryAnnotation = gpuBinaryAnnotation.str();
47   }
48 
49   // Run the dialect converter on the module.
50   void runOnOperation() override;
51 };
52 
53 class FunctionCallBuilder {
54 public:
FunctionCallBuilder(StringRef functionName,LLVM::LLVMType returnType,ArrayRef<LLVM::LLVMType> argumentTypes)55   FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
56                       ArrayRef<LLVM::LLVMType> argumentTypes)
57       : functionName(functionName),
58         functionType(LLVM::LLVMType::getFunctionTy(returnType, argumentTypes,
59                                                    /*isVarArg=*/false)) {}
60   LLVM::CallOp create(Location loc, OpBuilder &builder,
61                       ArrayRef<Value> arguments) const;
62 
63 private:
64   StringRef functionName;
65   LLVM::LLVMType functionType;
66 };
67 
68 template <typename OpTy>
69 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
70 public:
ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)71   explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
72       : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
73 
74 protected:
75   MLIRContext *context = &this->getTypeConverter()->getContext();
76 
77   LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
78   LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
79   LLVM::LLVMType llvmPointerPointerType = llvmPointerType.getPointerTo();
80   LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context);
81   LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
82   LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
83   LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy(
84       context, this->getTypeConverter()->getPointerBitwidth(0));
85 
86   FunctionCallBuilder moduleLoadCallBuilder = {
87       "mgpuModuleLoad",
88       llvmPointerType /* void *module */,
89       {llvmPointerType /* void *cubin */}};
90   FunctionCallBuilder moduleUnloadCallBuilder = {
91       "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
92   FunctionCallBuilder moduleGetFunctionCallBuilder = {
93       "mgpuModuleGetFunction",
94       llvmPointerType /* void *function */,
95       {
96           llvmPointerType, /* void *module */
97           llvmPointerType  /* char *name   */
98       }};
99   FunctionCallBuilder launchKernelCallBuilder = {
100       "mgpuLaunchKernel",
101       llvmVoidType,
102       {
103           llvmPointerType,        /* void* f */
104           llvmIntPtrType,         /* intptr_t gridXDim */
105           llvmIntPtrType,         /* intptr_t gridyDim */
106           llvmIntPtrType,         /* intptr_t gridZDim */
107           llvmIntPtrType,         /* intptr_t blockXDim */
108           llvmIntPtrType,         /* intptr_t blockYDim */
109           llvmIntPtrType,         /* intptr_t blockZDim */
110           llvmInt32Type,          /* unsigned int sharedMemBytes */
111           llvmPointerType,        /* void *hstream */
112           llvmPointerPointerType, /* void **kernelParams */
113           llvmPointerPointerType  /* void **extra */
114       }};
115   FunctionCallBuilder streamCreateCallBuilder = {
116       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
117   FunctionCallBuilder streamDestroyCallBuilder = {
118       "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
119   FunctionCallBuilder streamSynchronizeCallBuilder = {
120       "mgpuStreamSynchronize",
121       llvmVoidType,
122       {llvmPointerType /* void *stream */}};
123   FunctionCallBuilder streamWaitEventCallBuilder = {
124       "mgpuStreamWaitEvent",
125       llvmVoidType,
126       {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
127   FunctionCallBuilder eventCreateCallBuilder = {
128       "mgpuEventCreate", llvmPointerType /* void *event */, {}};
129   FunctionCallBuilder eventDestroyCallBuilder = {
130       "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
131   FunctionCallBuilder eventSynchronizeCallBuilder = {
132       "mgpuEventSynchronize",
133       llvmVoidType,
134       {llvmPointerType /* void *event */}};
135   FunctionCallBuilder eventRecordCallBuilder = {
136       "mgpuEventRecord",
137       llvmVoidType,
138       {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
139   FunctionCallBuilder hostRegisterCallBuilder = {
140       "mgpuMemHostRegisterMemRef",
141       llvmVoidType,
142       {llvmIntPtrType /* intptr_t rank */,
143        llvmPointerType /* void *memrefDesc */,
144        llvmIntPtrType /* intptr_t elementSizeBytes */}};
145   FunctionCallBuilder allocCallBuilder = {
146       "mgpuMemAlloc",
147       llvmPointerType /* void * */,
148       {llvmIntPtrType /* intptr_t sizeBytes */,
149        llvmPointerType /* void *stream */}};
150   FunctionCallBuilder deallocCallBuilder = {
151       "mgpuMemFree",
152       llvmVoidType,
153       {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
154 };
155 
156 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
157 /// call. Currently it supports CUDA and ROCm (HIP).
158 class ConvertHostRegisterOpToGpuRuntimeCallPattern
159     : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
160 public:
ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)161   ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
162       : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
163 
164 private:
165   LogicalResult
166   matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
167                   ConversionPatternRewriter &rewriter) const override;
168 };
169 
170 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
171 /// call. Currently it supports CUDA and ROCm (HIP).
172 class ConvertAllocOpToGpuRuntimeCallPattern
173     : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
174 public:
ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)175   ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
176       : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
177 
178 private:
179   LogicalResult
180   matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
181                   ConversionPatternRewriter &rewriter) const override;
182 };
183 
184 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
185 /// call. Currently it supports CUDA and ROCm (HIP).
186 class ConvertDeallocOpToGpuRuntimeCallPattern
187     : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
188 public:
ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)189   ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
190       : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
191 
192 private:
193   LogicalResult
194   matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
195                   ConversionPatternRewriter &rewriter) const override;
196 };
197 
198 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
199 /// call. Currently it supports CUDA and ROCm (HIP).
200 class ConvertWaitOpToGpuRuntimeCallPattern
201     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
202 public:
ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)203   ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
204       : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
205 
206 private:
207   LogicalResult
208   matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
209                   ConversionPatternRewriter &rewriter) const override;
210 };
211 
212 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
213 /// call. Currently it supports CUDA and ROCm (HIP).
214 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
215     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
216 public:
ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)217   ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
218       : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
219 
220 private:
221   LogicalResult
222   matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
223                   ConversionPatternRewriter &rewriter) const override;
224 };
225 
226 /// A rewrite patter to convert gpu.launch_func operations into a sequence of
227 /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
228 ///
229 /// In essence, a gpu.launch_func operations gets compiled into the following
230 /// sequence of runtime calls:
231 ///
232 /// * moduleLoad        -- loads the module given the cubin / hsaco data
233 /// * moduleGetFunction -- gets a handle to the actual kernel function
234 /// * getStreamHelper   -- initializes a new compute stream on GPU
235 /// * launchKernel      -- launches the kernel on a stream
236 /// * streamSynchronize -- waits for operations on the stream to finish
237 ///
238 /// Intermediate data structures are allocated on the stack.
239 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
240     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
241 public:
ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter,StringRef gpuBinaryAnnotation)242   ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
243                                              StringRef gpuBinaryAnnotation)
244       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
245         gpuBinaryAnnotation(gpuBinaryAnnotation) {}
246 
247 private:
248   Value generateParamsArray(gpu::LaunchFuncOp launchOp,
249                             ArrayRef<Value> operands, OpBuilder &builder) const;
250   Value generateKernelNameConstant(StringRef moduleName, StringRef name,
251                                    Location loc, OpBuilder &builder) const;
252 
253   LogicalResult
254   matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
255                   ConversionPatternRewriter &rewriter) const override;
256 
257   llvm::SmallString<32> gpuBinaryAnnotation;
258 };
259 
260 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
261   using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;
262 
matchAndRewrite(gpu::GPUModuleOp op,PatternRewriter & rewriter) const263   LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
264                                 PatternRewriter &rewriter) const override {
265     // GPU kernel modules are no longer necessary since we have a global
266     // constant with the CUBIN, or HSACO data.
267     rewriter.eraseOp(op);
268     return success();
269   }
270 };
271 } // namespace
272 
runOnOperation()273 void GpuToLLVMConversionPass::runOnOperation() {
274   LLVMTypeConverter converter(&getContext());
275   OwningRewritePatternList patterns;
276   populateStdToLLVMConversionPatterns(converter, patterns);
277   populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
278 
279   LLVMConversionTarget target(getContext());
280   if (failed(
281           applyPartialConversion(getOperation(), target, std::move(patterns))))
282     signalPassFailure();
283 }
284 
create(Location loc,OpBuilder & builder,ArrayRef<Value> arguments) const285 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
286                                          ArrayRef<Value> arguments) const {
287   auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
288   auto function = [&] {
289     if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
290       return function;
291     return OpBuilder(module.getBody()->getTerminator())
292         .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
293   }();
294   return builder.create<LLVM::CallOp>(
295       loc, const_cast<LLVM::LLVMType &>(functionType).getFunctionResultType(),
296       builder.getSymbolRefAttr(function), arguments);
297 }
298 
299 // Returns whether all operands are of LLVM type.
areAllLLVMTypes(Operation * op,ValueRange operands,ConversionPatternRewriter & rewriter)300 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
301                                      ConversionPatternRewriter &rewriter) {
302   if (!llvm::all_of(operands, [](Value value) {
303         return value.getType().isa<LLVM::LLVMType>();
304       }))
305     return rewriter.notifyMatchFailure(
306         op, "Cannot convert if operands aren't of LLVM type.");
307   return success();
308 }
309 
310 static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter & rewriter,gpu::AsyncOpInterface op)311 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
312                          gpu::AsyncOpInterface op) {
313   if (op.getAsyncDependencies().size() != 1)
314     return rewriter.notifyMatchFailure(
315         op, "Can only convert with exactly one async dependency.");
316 
317   if (!op.getAsyncToken())
318     return rewriter.notifyMatchFailure(op, "Can convert only async version.");
319 
320   return success();
321 }
322 
matchAndRewrite(gpu::HostRegisterOp hostRegisterOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const323 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
324     gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
325     ConversionPatternRewriter &rewriter) const {
326   auto *op = hostRegisterOp.getOperation();
327   if (failed(areAllLLVMTypes(op, operands, rewriter)))
328     return failure();
329 
330   Location loc = op->getLoc();
331 
332   auto memRefType = hostRegisterOp.value().getType();
333   auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
334   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
335 
336   auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
337                                                        operands, rewriter);
338   arguments.push_back(elementSize);
339   hostRegisterCallBuilder.create(loc, rewriter, arguments);
340 
341   rewriter.eraseOp(op);
342   return success();
343 }
344 
matchAndRewrite(gpu::AllocOp allocOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const345 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
346     gpu::AllocOp allocOp, ArrayRef<Value> operands,
347     ConversionPatternRewriter &rewriter) const {
348   MemRefType memRefType = allocOp.getType();
349 
350   if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
351       !isSupportedMemRefType(memRefType) ||
352       failed(isAsyncWithOneDependency(rewriter, allocOp)))
353     return failure();
354 
355   auto loc = allocOp.getLoc();
356 
357   // Get shape of the memref as values: static sizes are constant
358   // values and dynamic sizes are passed to 'alloc' as operands.
359   SmallVector<Value, 4> shape;
360   SmallVector<Value, 4> strides;
361   Value sizeBytes;
362   getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, shape, strides,
363                            sizeBytes);
364 
365   // Allocate the underlying buffer and store a pointer to it in the MemRef
366   // descriptor.
367   Type elementPtrType = this->getElementPtrType(memRefType);
368   auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());
369   auto stream = adaptor.asyncDependencies().front();
370   Value allocatedPtr =
371       allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
372   allocatedPtr =
373       rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
374 
375   // No alignment.
376   Value alignedPtr = allocatedPtr;
377 
378   // Create the MemRef descriptor.
379   auto memRefDescriptor = this->createMemRefDescriptor(
380       loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
381 
382   rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
383 
384   return success();
385 }
386 
matchAndRewrite(gpu::DeallocOp deallocOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const387 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
388     gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
389     ConversionPatternRewriter &rewriter) const {
390   if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
391       failed(isAsyncWithOneDependency(rewriter, deallocOp)))
392     return failure();
393 
394   Location loc = deallocOp.getLoc();
395 
396   auto adaptor =
397       gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
398   Value pointer =
399       MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
400   auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
401   Value stream = adaptor.asyncDependencies().front();
402   deallocCallBuilder.create(loc, rewriter, {casted, stream});
403 
404   rewriter.replaceOp(deallocOp, {stream});
405   return success();
406 }
407 
408 // Converts `gpu.wait` to runtime calls. The operands are all CUDA or ROCm
409 // streams (i.e. void*). The converted op synchronizes the host with every
410 // stream and then destroys it. That is, it assumes that the stream is not used
411 // afterwards. In case this isn't correct, we will get a runtime error.
412 // Eventually, we will have a pass that guarantees this property.
matchAndRewrite(gpu::WaitOp waitOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const413 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
414     gpu::WaitOp waitOp, ArrayRef<Value> operands,
415     ConversionPatternRewriter &rewriter) const {
416   if (waitOp.asyncToken())
417     return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
418 
419   Location loc = waitOp.getLoc();
420 
421   for (auto asyncDependency : operands)
422     streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
423   for (auto asyncDependency : operands)
424     streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
425 
426   rewriter.eraseOp(waitOp);
427   return success();
428 }
429 
430 // Converts `gpu.wait async` to runtime calls. The result is a new stream that
431 // is synchronized with all operands, which are CUDA or ROCm streams (i.e.
432 // void*). We create and record an event after the definition of the stream
433 // and make the new stream wait on that event before destroying it again. This
434 // assumes that there is no other use between the definition and this op, and
435 // the plan is to have a pass that guarantees this property.
matchAndRewrite(gpu::WaitOp waitOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const436 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
437     gpu::WaitOp waitOp, ArrayRef<Value> operands,
438     ConversionPatternRewriter &rewriter) const {
439   if (!waitOp.asyncToken())
440     return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
441 
442   Location loc = waitOp.getLoc();
443 
444   auto insertionPoint = rewriter.saveInsertionPoint();
445   SmallVector<Value, 1> events;
446   for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
447     auto token = std::get<0>(pair);
448     if (auto *defOp = token.getDefiningOp()) {
449       rewriter.setInsertionPointAfter(defOp);
450     } else {
451       // If we can't find the defining op, we record the event at block start,
452       // which is late and therefore misses parallelism, but still valid.
453       rewriter.setInsertionPointToStart(waitOp->getBlock());
454     }
455     auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
456     auto stream = std::get<1>(pair);
457     eventRecordCallBuilder.create(loc, rewriter, {event, stream});
458     events.push_back(event);
459   }
460   rewriter.restoreInsertionPoint(insertionPoint);
461   auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
462   for (auto event : events)
463     streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
464   for (auto event : events)
465     eventDestroyCallBuilder.create(loc, rewriter, {event});
466   rewriter.replaceOp(waitOp, {stream});
467 
468   return success();
469 }
470 
471 // Creates a struct containing all kernel parameters on the stack and returns
472 // an array of type-erased pointers to the fields of the struct. The array can
473 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
474 // The generated code is essentially as follows:
475 //
476 // %struct = alloca(sizeof(struct { Parameters... }))
477 // %array = alloca(NumParameters * sizeof(void *))
478 // for (i : [0, NumParameters))
479 //   %fieldPtr = llvm.getelementptr %struct[0, i]
480 //   llvm.store parameters[i], %fieldPtr
481 //   %elementPtr = llvm.getelementptr %array[i]
482 //   llvm.store %fieldPtr, %elementPtr
483 // return %array
generateParamsArray(gpu::LaunchFuncOp launchOp,ArrayRef<Value> operands,OpBuilder & builder) const484 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
485     gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
486     OpBuilder &builder) const {
487   auto loc = launchOp.getLoc();
488   auto numKernelOperands = launchOp.getNumKernelOperands();
489   auto arguments = getTypeConverter()->promoteOperands(
490       loc, launchOp.getOperands().take_back(numKernelOperands),
491       operands.take_back(numKernelOperands), builder);
492   auto numArguments = arguments.size();
493   SmallVector<LLVM::LLVMType, 4> argumentTypes;
494   argumentTypes.reserve(numArguments);
495   for (auto argument : arguments)
496     argumentTypes.push_back(argument.getType().cast<LLVM::LLVMType>());
497   auto structType = LLVM::LLVMType::createStructTy(argumentTypes, StringRef());
498   auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
499                                               builder.getI32IntegerAttr(1));
500   auto structPtr = builder.create<LLVM::AllocaOp>(
501       loc, structType.getPointerTo(), one, /*alignment=*/0);
502   auto arraySize = builder.create<LLVM::ConstantOp>(
503       loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
504   auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
505                                                  arraySize, /*alignment=*/0);
506   auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
507                                                builder.getI32IntegerAttr(0));
508   for (auto en : llvm::enumerate(arguments)) {
509     auto index = builder.create<LLVM::ConstantOp>(
510         loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
511     auto fieldPtr = builder.create<LLVM::GEPOp>(
512         loc, argumentTypes[en.index()].getPointerTo(), structPtr,
513         ArrayRef<Value>{zero, index.getResult()});
514     builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
515     auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
516                                                   arrayPtr, index.getResult());
517     auto casted =
518         builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
519     builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
520   }
521   return arrayPtr;
522 }
523 
524 // Generates an LLVM IR dialect global that contains the name of the given
525 // kernel function as a C string, and returns a pointer to its beginning.
526 // The code is essentially:
527 //
528 // llvm.global constant @kernel_name("function_name\00")
529 // func(...) {
530 //   %0 = llvm.addressof @kernel_name
531 //   %1 = llvm.constant (0 : index)
532 //   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
533 // }
generateKernelNameConstant(StringRef moduleName,StringRef name,Location loc,OpBuilder & builder) const534 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
535     StringRef moduleName, StringRef name, Location loc,
536     OpBuilder &builder) const {
537   // Make sure the trailing zero is included in the constant.
538   std::vector<char> kernelName(name.begin(), name.end());
539   kernelName.push_back('\0');
540 
541   std::string globalName =
542       std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
543   return LLVM::createGlobalString(
544       loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
545       LLVM::Linkage::Internal);
546 }
547 
548 // Emits LLVM IR to launch a kernel function. Expects the module that contains
549 // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
550 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
551 //
552 // %0 = call %binarygetter
553 // %1 = call %moduleLoad(%0)
554 // %2 = <see generateKernelNameConstant>
555 // %3 = call %moduleGetFunction(%1, %2)
556 // %4 = call %streamCreate()
557 // %5 = <see generateParamsArray>
558 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
559 // call %streamSynchronize(%4)
560 // call %streamDestroy(%4)
561 // call %moduleUnload(%1)
562 //
563 // If the op is async, the stream corresponds to the (single) async dependency
564 // as well as the async token the op produces.
matchAndRewrite(gpu::LaunchFuncOp launchOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const565 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
566     gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
567     ConversionPatternRewriter &rewriter) const {
568   if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
569     return failure();
570 
571   if (launchOp.asyncDependencies().size() > 1)
572     return rewriter.notifyMatchFailure(
573         launchOp, "Cannot convert with more than one async dependency.");
574 
575   // Fail when the synchronous version of the op has async dependencies. The
576   // lowering destroys the stream, and we do not want to check that there is no
577   // use of the stream after this op.
578   if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
579     return rewriter.notifyMatchFailure(
580         launchOp, "Cannot convert non-async op with async dependencies.");
581 
582   Location loc = launchOp.getLoc();
583 
584   // Create an LLVM global with CUBIN extracted from the kernel annotation and
585   // obtain a pointer to the first byte in it.
586   auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
587       launchOp, launchOp.getKernelModuleName());
588   assert(kernelModule && "expected a kernel module");
589 
590   auto binaryAttr =
591       kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
592   if (!binaryAttr) {
593     kernelModule.emitOpError()
594         << "missing " << gpuBinaryAnnotation << " attribute";
595     return failure();
596   }
597 
598   SmallString<128> nameBuffer(kernelModule.getName());
599   nameBuffer.append(kGpuBinaryStorageSuffix);
600   Value data =
601       LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
602                                binaryAttr.getValue(), LLVM::Linkage::Internal);
603 
604   auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
605   // Get the function from the module. The name corresponds to the name of
606   // the kernel function.
607   auto kernelName = generateKernelNameConstant(
608       launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
609   auto function = moduleGetFunctionCallBuilder.create(
610       loc, rewriter, {module.getResult(0), kernelName});
611   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
612                                                 rewriter.getI32IntegerAttr(0));
613   auto adaptor =
614       gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
615   Value stream =
616       adaptor.asyncDependencies().empty()
617           ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
618           : adaptor.asyncDependencies().front();
619   // Create array of pointers to kernel arguments.
620   auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
621   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
622   launchKernelCallBuilder.create(loc, rewriter,
623                                  {function.getResult(0), launchOp.gridSizeX(),
624                                   launchOp.gridSizeY(), launchOp.gridSizeZ(),
625                                   launchOp.blockSizeX(), launchOp.blockSizeY(),
626                                   launchOp.blockSizeZ(),
627                                   /*sharedMemBytes=*/zero, stream, kernelParams,
628                                   /*extra=*/nullpointer});
629 
630   if (launchOp.asyncToken()) {
631     // Async launch: make dependent ops use the same stream.
632     rewriter.replaceOp(launchOp, {stream});
633   } else {
634     // Synchronize with host and destroy stream. This must be the stream created
635     // above (with no other uses) because we check that the synchronous version
636     // does not have any async dependencies.
637     streamSynchronizeCallBuilder.create(loc, rewriter, stream);
638     streamDestroyCallBuilder.create(loc, rewriter, stream);
639     rewriter.eraseOp(launchOp);
640   }
641   moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
642 
643   return success();
644 }
645 
646 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation)647 mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
648   return std::make_unique<GpuToLLVMConversionPass>(gpuBinaryAnnotation);
649 }
650 
populateGpuToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns,StringRef gpuBinaryAnnotation)651 void mlir::populateGpuToLLVMConversionPatterns(
652     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
653     StringRef gpuBinaryAnnotation) {
654   converter.addConversion(
655       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
656         return LLVM::LLVMType::getInt8PtrTy(context);
657       });
658   patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
659                   ConvertDeallocOpToGpuRuntimeCallPattern,
660                   ConvertHostRegisterOpToGpuRuntimeCallPattern,
661                   ConvertWaitAsyncOpToGpuRuntimeCallPattern,
662                   ConvertWaitOpToGpuRuntimeCallPattern>(converter);
663   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
664       converter, gpuBinaryAnnotation);
665   patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
666 }
667