• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
21 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
22 #include "mlir/Conversion/LLVMCommon/Pattern.h"
23 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
24 #include "mlir/Dialect/GPU/GPUDialect.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
27 #include "mlir/Dialect/Math/IR/Math.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"
29 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
30 #include "mlir/IR/Attributes.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/Operation.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 
// This file implements the conversion of disc ral ops to the llvm dialect.
37 
38 namespace mlir {
39 namespace disc_ral {
40 
41 using LLVM::GlobalOp;
42 using LLVM::LLVMFuncOp;
43 using StrT = SmallString<128>;
44 
45 namespace {
46 
// Name of the string attribute on a gpu module that holds its binary blob.
constexpr const char* kGpuBinaryAttrName = "gpu.binary_blob";
// Call target name used for the dispatch op that replaces gpu.launch_func.
constexpr const char* kRalGpuLaunch = "ral_kernel_launch";
// Symbol name of the llvm function that every lowered dispatch op calls.
constexpr const char* kRalDispatchFunctionName = "disc_ral_call";
50 
51 // Encodes a mlir type and appends the encoding to the string buffer `out`.
getTypeEncoding(MLIRContext * ctx,Type t,StrT & out)52 LogicalResult getTypeEncoding(MLIRContext* ctx, Type t, StrT& out) {
53   Type llvm_pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
54   Type llvm_pointer_pointer_type =
55       LLVM::LLVMPointerType::get(llvm_pointer_type);
56   if (auto memref_type = t.dyn_cast<MemRefType>()) {
57     out.append(
58         Twine("m").concat(Twine(memref_type.getRank()).concat("d")).str());
59     return getTypeEncoding(ctx, memref_type.getElementType(), out);
60   } else if (auto int_type = t.dyn_cast<IntegerType>()) {
61     out.append(Twine("i").concat(Twine(int_type.getWidth())).str());
62   } else if (auto fp_type = t.dyn_cast<FloatType>()) {
63     out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
64   } else if (auto ctx_type = t.dyn_cast<RalExecutionContextType>() ||
65                              t == llvm_pointer_type) {
66     out.append("pvoid");
67   } else if (t == llvm_pointer_pointer_type) {
68     out.append("ppvoid");
69   } else if (t.isIndex()) {
70     // index is mapping to int64_t a.t.m. Re-visit this in case necessary.
71     out.append("i64");
72   } else {
73     // unknown type
74     return failure();
75   }
76   return success();
77 }
78 
79 // Encodes a ral_dispatch op and appends the encoding to the string buffer
80 // `out`. The format:
81 //   encoding = separator.join(target_name, device, inputs_encode,
82 //   outputs_encode)
83 //
84 //   separator = '___'
85 //
86 //   target_name: name of the external function to dispatch.
87 //
88 //   device: user defined string (e.g. cpu or gpu)
89 //
90 //   inputs_encode = type_separator.join([type_encoding for type in
91 //   input_types])
92 //
93 //   outputs_encode = type_separator.join([type_encoding for type
94 //   in output_types])
95 //
96 //   type_separator = '_'
getDispatchOpSignatureEncoding(DispatchOp dispatch_op,StrT & out)97 LogicalResult getDispatchOpSignatureEncoding(DispatchOp dispatch_op,
98                                              StrT& out) {
99   const char* separator = "___";
100   // append signature prefix
101   out.append(dispatch_op.call_target_name());
102   out.append(separator);
103 
104   // encode backend (device) info
105   out.append(dispatch_op.backend_config());
106   out.append(separator);
107 
108   // encode input types
109   Operation* op = dispatch_op.getOperation();
110   for (auto& en : llvm::enumerate(op->getOperandTypes())) {
111     if (en.index() != 0) out.append("_");
112     if (failed(getTypeEncoding(op->getContext(), en.value(), out)))
113       return failure();
114   }
115   out.append(separator);
116 
117   // encode output types
118   for (auto& en : llvm::enumerate(op->getResultTypes())) {
119     if (en.index() != 0) out.append("_");
120     if (failed(getTypeEncoding(op->getContext(), en.value(), out)))
121       return failure();
122   }
123   if (!op->getNumResults()) out.append("void");
124   return success();
125 }
126 
127 // Loads a global op at the current insertion point and returns the loaded
128 // value.
loadGlobalString(OpBuilder & builder,const Location & loc,GlobalOp globalOp)129 Value loadGlobalString(OpBuilder& builder, const Location& loc,
130                        GlobalOp globalOp) {
131   MLIRContext* ctx = builder.getContext();
132   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, globalOp);
133   Value cst0 = builder.create<LLVM::ConstantOp>(
134       loc, IntegerType::get(ctx, 64),
135       builder.getIntegerAttr(builder.getIndexType(), 0));
136   return builder.create<LLVM::GEPOp>(
137       loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
138       ValueRange{cst0, cst0});
139 }
140 
141 // Returns true if the globalOp has the same value as `value`.
checkGlobalOpContent(GlobalOp globalOp,StringRef value)142 bool checkGlobalOpContent(GlobalOp globalOp, StringRef value) {
143   Optional<Attribute> optValue = globalOp.value();
144   if (!optValue) return false;
145 
146   StringAttr attr = (*optValue).cast<StringAttr>();
147   if (!attr) return false;
148 
149   return attr.getValue() == value;
150 }
151 
152 // Creates a global const string op named `name` using the value if not exists
153 // and returns the Loaded value of this global op.
loadOrCreateGlobalString(PatternRewriter & rewriter,SymbolTable & symbol_table,Operation * op,StringRef name,StringRef value)154 Value loadOrCreateGlobalString(PatternRewriter& rewriter,
155                                SymbolTable& symbol_table, Operation* op,
156                                StringRef name, StringRef value) {
157   ModuleOp module = op->getParentOfType<ModuleOp>();
158   GlobalOp globalOp = symbol_table.lookup<GlobalOp>(name);
159   if (!globalOp) {
160     OpBuilder::InsertionGuard guard(rewriter);
161     OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
162     rewriter.setInsertionPointToStart(module.getBody());
163 
164     auto type = LLVM::LLVMArrayType::get(IntegerType::get(op->getContext(), 8),
165                                          value.size());
166     globalOp = rewriter.create<LLVM::GlobalOp>(
167         op->getLoc(), type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
168         rewriter.getStringAttr(value), /*alignment=*/0);
169 
170     // Update the symbol table
171     symbol_table.insert(globalOp);
172 
173     rewriter.restoreInsertionPoint(ip);
174   } else {
175     assert(checkGlobalOpContent(globalOp, value));
176   }
177 
178   return loadGlobalString(rewriter, op->getLoc(), globalOp);
179 }
180 
181 // Converts a ral.dispatch_op to its llvm format.
182 class DispatchOpToLLVMPattern : public ConvertOpToLLVMPattern<DispatchOp> {
183  public:
DispatchOpToLLVMPattern(LLVMTypeConverter & type_converter,SymbolTable & symbol_table)184   DispatchOpToLLVMPattern(LLVMTypeConverter& type_converter,
185                           SymbolTable& symbol_table)
186       : ConvertOpToLLVMPattern<DispatchOp>(type_converter),
187         symbol_table_(symbol_table) {}
188 
189   // Returns the ral dispatch function and inserts the declaration if not found.
190   LLVMFuncOp getOrInsertDispatchFunction(PatternRewriter& rewriter,
191                                          Operation* op) const;
192 
193   // Packs the inputs and outputs into a type-erased pointer array.
194   // For example, `int func(int)` -> `void func(void* args[]) where args =
195   // {in_ptr, out_ptr}`
196   Value rewriteInsOutsOfDispatchOp(DispatchOp dispatch_op,
197                                    ArrayRef<Value> operands,
198                                    ConversionPatternRewriter& rewriter,
199                                    SmallVectorImpl<Value>& resultPtrs) const;
200 
201   LogicalResult matchAndRewrite(
202       DispatchOp dispatch_op, ArrayRef<Value> operands,
203       ConversionPatternRewriter& rewriter) const override;
204 
205  private:
206   SymbolTable& symbol_table_;
207 };
208 
209 // Returns the llvm function definition of ral dispatch op and creates it first
210 // if not exists.
getOrInsertDispatchFunction(PatternRewriter & rewriter,Operation * op) const211 LLVMFuncOp DispatchOpToLLVMPattern::getOrInsertDispatchFunction(
212     PatternRewriter& rewriter, Operation* op) const {
213   ModuleOp module = op->getParentOfType<ModuleOp>();
214   LLVMFuncOp func = symbol_table_.lookup<LLVMFuncOp>(kRalDispatchFunctionName);
215 
216   if (func) return func;
217 
218   // Try to insert the function since it's not found.
219   OpBuilder::InsertionGuard guard(rewriter);
220   OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
221   rewriter.setInsertionPointToStart(module.getBody());
222   Type llvm_pointer_type =
223       LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
224   Type llvm_pointer_pointer_type =
225       LLVM::LLVMPointerType::get(llvm_pointer_type);
226   func = rewriter.create<LLVMFuncOp>(
227       op->getLoc(), kRalDispatchFunctionName,
228       LLVM::LLVMFunctionType::get(
229           getVoidType(),
230           {
231               llvm_pointer_type,        /* ral_context_t */
232               llvm_pointer_type,        /* void* call_target_name */
233               llvm_pointer_pointer_type /* void** args */
234           },
235           /*isVarArg=*/false));
236 
237   symbol_table_.insert(func);
238 
239   rewriter.restoreInsertionPoint(ip);
240 
241   return func;
242 }
243 
244 // Packs the original inputs and outputs of the ral dispatch op to a uniform
245 // format.
246 //
247 // %struct = alloca(sizeof(struct { Parameters..., Results..., }))
248 // %array = alloca((NumParameters + NumResult) * sizeof(void *))
249 // for (i : [0, NumParameters))
250 //   %fieldPtr = llvm.getelementptr %struct[0, i]
251 //   llvm.store parameters[i], %fieldPtr
252 //   %elementPtr = llvm.getelementptr %array[i]
253 //   llvm.store %fieldPtr, %elementPtr
254 // for (i : [NumParameters, NumParameters + NumResult))
255 //   %fieldPtr = llvm.getelementptr %struct[0, i]
256 //   %elementPtr = llvm.getelementptr %array[i]
257 //   llvm.store %fieldPtr, %elementPtr
258 // return %array
rewriteInsOutsOfDispatchOp(DispatchOp dispatch_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter,SmallVectorImpl<Value> & resultPtrs) const259 Value DispatchOpToLLVMPattern::rewriteInsOutsOfDispatchOp(
260     DispatchOp dispatch_op, ArrayRef<Value> operands,
261     ConversionPatternRewriter& rewriter,
262     SmallVectorImpl<Value>& resultPtrs) const {
263   MLIRContext* ctx = rewriter.getContext();
264   Location loc = dispatch_op.getLoc();
265 
266   Type llvm_pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
267   Type llvm_pointer_pointer_type =
268       LLVM::LLVMPointerType::get(llvm_pointer_type);
269   Type llvm_int32_type = IntegerType::get(ctx, 32);
270 
271   Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
272                                                  rewriter.getI32IntegerAttr(0));
273   Value one = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
274                                                 rewriter.getI32IntegerAttr(1));
275 
276   SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
277       loc, dispatch_op.getOperands(), operands, rewriter);
278   SmallVector<Type, 4> argument_types;
279   for (auto argument : arguments) argument_types.push_back(argument.getType());
280   for (auto resultType : dispatch_op.getResultTypes())
281     argument_types.push_back(getTypeConverter()->convertType(resultType));
282 
283   auto struct_type =
284       LLVM::LLVMStructType::getNewIdentified(ctx, StringRef(), argument_types);
285   Value struct_ptr = rewriter.create<LLVM::AllocaOp>(
286       loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
287   Value array_size = rewriter.create<LLVM::ConstantOp>(
288       loc, llvm_int32_type, rewriter.getI32IntegerAttr(argument_types.size()));
289   Value array_ptr = rewriter.create<LLVM::AllocaOp>(
290       loc, llvm_pointer_pointer_type, array_size, /*alignment=*/0);
291 
292   for (auto en : llvm::enumerate(argument_types)) {
293     Value index = rewriter.create<LLVM::ConstantOp>(
294         loc, llvm_int32_type, rewriter.getI32IntegerAttr(en.index()));
295     Value field_ptr = rewriter.create<LLVM::GEPOp>(
296         loc, LLVM::LLVMPointerType::get(en.value()), struct_ptr,
297         ArrayRef<Value>{zero, index});
298     if (en.index() < arguments.size()) {
299       rewriter.create<LLVM::StoreOp>(loc, arguments[en.index()], field_ptr);
300     } else {
301       resultPtrs.push_back(field_ptr);
302     }
303 
304     Value element_ptr = rewriter.create<LLVM::GEPOp>(
305         loc, llvm_pointer_pointer_type, array_ptr, index);
306     Value casted =
307         rewriter.create<LLVM::BitcastOp>(loc, llvm_pointer_type, field_ptr);
308     rewriter.create<LLVM::StoreOp>(loc, casted, element_ptr);
309   }
310 
311   return array_ptr;
312 }
313 
matchAndRewrite(DispatchOp dispatch_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const314 LogicalResult DispatchOpToLLVMPattern::matchAndRewrite(
315     DispatchOp dispatch_op, ArrayRef<Value> operands,
316     ConversionPatternRewriter& rewriter) const {
317   StrT target_name;
318   if (failed(getDispatchOpSignatureEncoding(dispatch_op, target_name))) {
319     dispatch_op->emitError("unknown types in the dispatch op");
320     return failure();
321   }
322 
323   // Make sure the trailing zero is included in the constant.
324   target_name.push_back('\0');
325 
326   Operation* op = dispatch_op.getOperation();
327   Location loc = op->getLoc();
328   DispatchOp::Adaptor adaptor(operands);
329   SmallVector<Value, 3> callOpOperands;
330   LLVMFuncOp dispatch_func = getOrInsertDispatchFunction(rewriter, op);
331 
332   SmallVector<Value, 1> resultPtrs;
333   Value packedArgs =
334       rewriteInsOutsOfDispatchOp(dispatch_op, operands, rewriter, resultPtrs);
335 
336   // the first argument is ral_context
337   callOpOperands.push_back(adaptor.ctx());
338   // the second argument is the target name
339   callOpOperands.push_back(loadOrCreateGlobalString(
340       rewriter, symbol_table_, op, target_name.str().drop_back(),
341       target_name.str()));
342   // the third argument is the args for target function
343   callOpOperands.push_back(packedArgs);
344 
345   rewriter.create<LLVM::CallOp>(loc, llvm::None,
346                                 rewriter.getSymbolRefAttr(dispatch_func),
347                                 callOpOperands);
348 
349   SmallVector<Value, 1> results;
350   llvm::transform(resultPtrs, std::back_inserter(results), [&](Value v) {
351     return rewriter.create<LLVM::LoadOp>(loc, v);
352   });
353 
354   rewriter.replaceOp(op, results);
355 
356   return success();
357 }
358 
359 // A rewrite pattern to convert gpu.launch_func operations into corresponding
360 // runtime wrapper calls (modeled by ral.dispatch ops)
361 class ConvertLaunchFuncOpToRalCallPattern
362     : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
363  public:
ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter & type_converter,SymbolTable & symbol_table)364   ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter& type_converter,
365                                       SymbolTable& symbol_table)
366       : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
367         symbol_table_(symbol_table) {}
368 
369  private:
370   Value generateParamsArray(gpu::LaunchFuncOp launch_op,
371                             ArrayRef<Value> operands, OpBuilder& builder) const;
372   Value generateKernelNameConstant(StringRef moduleName, StringRef name,
373                                    Location loc, OpBuilder& builder) const;
374 
375   LogicalResult matchAndRewrite(
376       gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
377       ConversionPatternRewriter& rewriter) const override;
378 
379   SymbolTable& symbol_table_;
380 };
381 
382 // Creates a struct containing all kernel parameters on the stack and returns
383 // an array of type-erased pointers to the fields of the struct. The array can
384 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
385 // The generated code is essentially as follows:
386 //
387 // %struct = alloca(sizeof(struct { Parameters... }))
388 // %array = alloca(NumParameters * sizeof(void *))
389 // for (i : [0, NumParameters))
390 //   %fieldPtr = llvm.getelementptr %struct[0, i]
391 //   llvm.store parameters[i], %fieldPtr
392 //   %elementPtr = llvm.getelementptr %array[i]
393 //   llvm.store %fieldPtr, %elementPtr
394 // return %array
generateParamsArray(gpu::LaunchFuncOp launch_op,ArrayRef<Value> operands,OpBuilder & builder) const395 Value ConvertLaunchFuncOpToRalCallPattern::generateParamsArray(
396     gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
397     OpBuilder& builder) const {
398   MLIRContext* ctx = builder.getContext();
399   Type llvm_pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
400   Type llvm_pointer_pointer_type =
401       LLVM::LLVMPointerType::get(llvm_pointer_type);
402   Type llvm_int32_type = IntegerType::get(ctx, 32);
403 
404   Location loc = launch_op.getLoc();
405   int num_kernel_operands = launch_op.getNumKernelOperands();
406   auto arguments = getTypeConverter()->promoteOperands(
407       loc, launch_op.getOperands().take_back(num_kernel_operands),
408       operands.take_back(num_kernel_operands), builder);
409   int num_arguments = static_cast<int>(arguments.size());
410   SmallVector<Type, 4> argument_types;
411   argument_types.reserve(num_arguments);
412   for (auto argument : arguments) argument_types.push_back(argument.getType());
413   auto struct_type =
414       LLVM::LLVMStructType::getNewIdentified(ctx, StringRef(), argument_types);
415   Value one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type,
416                                                builder.getI32IntegerAttr(1));
417   Value struct_ptr = builder.create<LLVM::AllocaOp>(
418       loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
419   Value array_size = builder.create<LLVM::ConstantOp>(
420       loc, llvm_int32_type, builder.getI32IntegerAttr(num_arguments));
421   Value array_ptr = builder.create<LLVM::AllocaOp>(
422       loc, llvm_pointer_pointer_type, array_size, /*alignment=*/0);
423   Value zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type,
424                                                 builder.getI32IntegerAttr(0));
425   for (auto en : llvm::enumerate(arguments)) {
426     Value index = builder.create<LLVM::ConstantOp>(
427         loc, llvm_int32_type, builder.getI32IntegerAttr(en.index()));
428     Value field_ptr = builder.create<LLVM::GEPOp>(
429         loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
430         ArrayRef<Value>{zero, index});
431     builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
432     Value element_ptr = builder.create<LLVM::GEPOp>(
433         loc, llvm_pointer_pointer_type, array_ptr, index);
434     Value casted =
435         builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type, field_ptr);
436     builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
437   }
438   return array_ptr;
439 }
440 
441 // Emits LLVM IR to launch a kernel function. Expects the module that contains
442 // the compiled kernel function as a cubin in the `kRalGpuLaunch` attribute.
matchAndRewrite(gpu::LaunchFuncOp launch_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const443 LogicalResult ConvertLaunchFuncOpToRalCallPattern::matchAndRewrite(
444     gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
445     ConversionPatternRewriter& rewriter) const {
446   if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
447     return rewriter.notifyMatchFailure(
448         launch_op, "Cannot convert with async dependency or result.");
449   }
450 
451   // Create an LLVM global with CUBIN extracted from the kernel annotation and
452   // obtain a pointer to the first byte in it.
453   auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
454       launch_op, launch_op.getKernelModuleName());
455   if (!kernel_module) {
456     launch_op.emitOpError() << "cannot find corresponding kernel module.";
457     return failure();
458   }
459 
460   auto binary_attr =
461       kernel_module->getAttrOfType<StringAttr>(kGpuBinaryAttrName);
462   if (!binary_attr) {
463     kernel_module.emitOpError()
464         << "missing " << kGpuBinaryAttrName << " attribute";
465     return failure();
466   }
467 
468   Operation* op = launch_op.getOperation();
469   Location loc = launch_op.getLoc();
470 
471   // Create a global for the module blob.
472   StrT name_buffer(kernel_module.getName());
473   name_buffer.append("_blob");
474 
475   Value module_blob = loadOrCreateGlobalString(
476       rewriter, symbol_table_, op, name_buffer.str(), binary_attr.getValue());
477 
478   // Make sure the trailing zero is included in the constant.
479   auto kernel_name = launch_op.getKernelName();
480   SmallString<128> kernel_name_buffer(kernel_name);
481   kernel_name_buffer.push_back('\0');
482 
483   // Create a global for the kernel name.
484   SmallString<128> kernel_name_global_name_buffer;
485   auto kernel_name_global_name =
486       (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
487           .toStringRef(kernel_name_global_name_buffer);
488   Value kernel_name_global = loadOrCreateGlobalString(
489       rewriter, symbol_table_, op, kernel_name_global_name,
490       kernel_name_buffer.str());
491 
492   auto adaptor =
493       gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());
494 
495   // The Ral Context is the first argument of the surrounding LLVMFunc.
496   Value context_arg =
497       launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
498   auto kernel_params = generateParamsArray(launch_op, operands, rewriter);
499 
500   Type llvm_int32_type = IntegerType::get(rewriter.getContext(), 32);
501   Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
502                                                  rewriter.getI32IntegerAttr(0));
503   // clang-format off
504   // TODO(disc): we use the default stream a.t.m. Implement a stream assignment
505   // algo in case necessary.
506   SmallVector<Value, 12> newOperands{
507       module_blob, /* gpu module string */
508       kernel_name_global, /* name of the kernel to launch */
509       adaptor.gridSizeX(), adaptor.gridSizeY(), adaptor.gridSizeZ(),
510       adaptor.blockSizeX(), adaptor.blockSizeY(), adaptor.blockSizeZ(),
511       zero, /* sharedMemBytes */
512       zero, /* gpu stream index */
513       kernel_params /* params for the kernel to launch */
514   };
515   // clang-format on
516 
517   rewriter.replaceOpWithNewOp<disc_ral::DispatchOp>(
518       launch_op, llvm::None, context_arg, newOperands, kRalGpuLaunch, false,
519       "cpu");
520   return success();
521 }
522 
523 class RalToLLVMPass : public RalToLLVMPassBase<RalToLLVMPass> {
getDependentDialects(DialectRegistry & registry) const524   void getDependentDialects(DialectRegistry& registry) const override {
525     registry.insert<LLVM::LLVMDialect>();
526   }
527 
528  public:
runOnOperation()529   void runOnOperation() override {
530     ModuleOp m = getOperation();
531     SymbolTable symbol_table(m);
532 
533     // Populate type conversions.
534     MLIRContext* ctx = m.getContext();
535     LLVMTypeConverter type_converter(ctx);
536     type_converter.addConversion([&](RalExecutionContextType type) {
537       return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
538     });
539 
540     // Populate patterns.
541     RewritePatternSet patterns(&getContext());
542     populateStdExpandOpsPatterns(patterns);
543     populateStdToLLVMConversionPatterns(type_converter, patterns);
544     populateDiscRalToLLVMConversionPatterns(&type_converter, &symbol_table,
545                                             &patterns);
546 
547     // Set target.
548     ConversionTarget target(*ctx);
549     target.addLegalDialect<LLVM::LLVMDialect>();
550     target.addIllegalDialect<StandardOpsDialect, gpu::GPUDialect,
551                              disc_ral::RalDialect, math::MathDialect>();
552     target.addIllegalOp<UnrealizedConversionCastOp>();
553     // Mark modules as legal.
554     target.addLegalOp<ModuleOp, gpu::GPUModuleOp>();
555     // Do not look into gpu modules, only consider host-side.
556     target.markOpRecursivelyLegal<gpu::GPUModuleOp>();
557 
558     if (failed(applyFullConversion(m, target, std::move(patterns)))) {
559       signalPassFailure();
560     }
561 
562     // Finally, strip the GPU modules, as they are no longer needed.
563     for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
564       op.erase();
565     }
566   }
567 };
568 
569 }  // namespace
570 
populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter * converter,SymbolTable * symbol_table,RewritePatternSet * patterns)571 void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter* converter,
572                                              SymbolTable* symbol_table,
573                                              RewritePatternSet* patterns) {
574   // clang-format off
575   patterns->insert<
576       ConvertLaunchFuncOpToRalCallPattern,
577       DispatchOpToLLVMPattern
578     >(*converter, *symbol_table);
579   // clang-format on
580 }
581 
createRalToLLVMPass()582 std::unique_ptr<OperationPass<ModuleOp>> createRalToLLVMPass() {
583   return std::make_unique<RalToLLVMPass>();
584 }
585 
586 }  // namespace disc_ral
587 }  // namespace mlir
588