1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
21 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
22 #include "mlir/Conversion/LLVMCommon/Pattern.h"
23 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
24 #include "mlir/Dialect/GPU/GPUDialect.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
27 #include "mlir/Dialect/Math/IR/Math.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"
29 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
30 #include "mlir/IR/Attributes.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/Operation.h"
34 #include "mlir/Transforms/DialectConversion.h"
35
36 // This file implements the logic to convert disc ral ops to llvm dialect
37
38 namespace mlir {
39 namespace disc_ral {
40
41 using LLVM::GlobalOp;
42 using LLVM::LLVMFuncOp;
43 using StrT = SmallString<128>;
44
45 namespace {
46
47 constexpr const char* kRalDispatchFunctionName = "disc_ral_call";
48 constexpr const char* kGpuBinaryAttrName = "gpu.binary_blob";
49 constexpr const char* kRalGpuLaunch = "ral_kernel_launch";
50
51 // Encodes a mlir type and appends the encoding to the string buffer `out`.
getTypeEncoding(MLIRContext * ctx,Type t,StrT & out)52 LogicalResult getTypeEncoding(MLIRContext* ctx, Type t, StrT& out) {
53 Type llvm_pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
54 Type llvm_pointer_pointer_type =
55 LLVM::LLVMPointerType::get(llvm_pointer_type);
56 if (auto memref_type = t.dyn_cast<MemRefType>()) {
57 out.append(
58 Twine("m").concat(Twine(memref_type.getRank()).concat("d")).str());
59 return getTypeEncoding(ctx, memref_type.getElementType(), out);
60 } else if (auto int_type = t.dyn_cast<IntegerType>()) {
61 out.append(Twine("i").concat(Twine(int_type.getWidth())).str());
62 } else if (auto fp_type = t.dyn_cast<FloatType>()) {
63 out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
64 } else if (auto ctx_type = t.dyn_cast<RalExecutionContextType>() ||
65 t == llvm_pointer_type) {
66 out.append("pvoid");
67 } else if (t == llvm_pointer_pointer_type) {
68 out.append("ppvoid");
69 } else if (t.isIndex()) {
70 // index is mapping to int64_t a.t.m. Re-visit this in case necessary.
71 out.append("i64");
72 } else {
73 // unknown type
74 return failure();
75 }
76 return success();
77 }
78
79 // Encodes a ral_dispatch op and appends the encoding to the string buffer
80 // `out`. The format:
81 // encoding = separator.join(target_name, device, inputs_encode,
82 // outputs_encode)
83 //
84 // separator = '___'
85 //
86 // target_name: name of the external function to dispatch.
87 //
88 // device: user defined string (e.g. cpu or gpu)
89 //
90 // inputs_encode = type_separator.join([type_encoding for type in
91 // input_types])
92 //
93 // outputs_encode = type_separator.join([type_encoding for type
94 // in output_types])
95 //
96 // type_separator = '_'
getDispatchOpSignatureEncoding(DispatchOp dispatch_op,StrT & out)97 LogicalResult getDispatchOpSignatureEncoding(DispatchOp dispatch_op,
98 StrT& out) {
99 const char* separator = "___";
100 // append signature prefix
101 out.append(dispatch_op.call_target_name());
102 out.append(separator);
103
104 // encode backend (device) info
105 out.append(dispatch_op.backend_config());
106 out.append(separator);
107
108 // encode input types
109 Operation* op = dispatch_op.getOperation();
110 for (auto& en : llvm::enumerate(op->getOperandTypes())) {
111 if (en.index() != 0) out.append("_");
112 if (failed(getTypeEncoding(op->getContext(), en.value(), out)))
113 return failure();
114 }
115 out.append(separator);
116
117 // encode output types
118 for (auto& en : llvm::enumerate(op->getResultTypes())) {
119 if (en.index() != 0) out.append("_");
120 if (failed(getTypeEncoding(op->getContext(), en.value(), out)))
121 return failure();
122 }
123 if (!op->getNumResults()) out.append("void");
124 return success();
125 }
126
127 // Loads a global op at the current insertion point and returns the loaded
128 // value.
loadGlobalString(OpBuilder & builder,const Location & loc,GlobalOp globalOp)129 Value loadGlobalString(OpBuilder& builder, const Location& loc,
130 GlobalOp globalOp) {
131 MLIRContext* ctx = builder.getContext();
132 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, globalOp);
133 Value cst0 = builder.create<LLVM::ConstantOp>(
134 loc, IntegerType::get(ctx, 64),
135 builder.getIntegerAttr(builder.getIndexType(), 0));
136 return builder.create<LLVM::GEPOp>(
137 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
138 ValueRange{cst0, cst0});
139 }
140
141 // Returns true if the globalOp has the same value as `value`.
checkGlobalOpContent(GlobalOp globalOp,StringRef value)142 bool checkGlobalOpContent(GlobalOp globalOp, StringRef value) {
143 Optional<Attribute> optValue = globalOp.value();
144 if (!optValue) return false;
145
146 StringAttr attr = (*optValue).cast<StringAttr>();
147 if (!attr) return false;
148
149 return attr.getValue() == value;
150 }
151
152 // Creates a global const string op named `name` using the value if not exists
153 // and returns the Loaded value of this global op.
Value loadOrCreateGlobalString(PatternRewriter& rewriter,
                               SymbolTable& symbol_table, Operation* op,
                               StringRef name, StringRef value) {
  ModuleOp module = op->getParentOfType<ModuleOp>();
  // Re-use an existing global with the same symbol name if one exists.
  GlobalOp globalOp = symbol_table.lookup<GlobalOp>(name);
  if (!globalOp) {
    // Create the global at the start of the module body so that it dominates
    // all uses; the guard/save/restore puts the rewriter back afterwards.
    OpBuilder::InsertionGuard guard(rewriter);
    OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointToStart(module.getBody());

    // The global is an i8 array holding the raw bytes of `value`.
    auto type = LLVM::LLVMArrayType::get(IntegerType::get(op->getContext(), 8),
                                         value.size());
    globalOp = rewriter.create<LLVM::GlobalOp>(
        op->getLoc(), type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
        rewriter.getStringAttr(value), /*alignment=*/0);

    // Update the symbol table so later lookups find this global.
    symbol_table.insert(globalOp);

    rewriter.restoreInsertionPoint(ip);
  } else {
    // A name collision with different content would indicate a codegen bug.
    assert(checkGlobalOpContent(globalOp, value));
  }

  return loadGlobalString(rewriter, op->getLoc(), globalOp);
}
180
181 // Converts a ral.dispatch_op to its llvm format.
class DispatchOpToLLVMPattern : public ConvertOpToLLVMPattern<DispatchOp> {
 public:
  DispatchOpToLLVMPattern(LLVMTypeConverter& type_converter,
                          SymbolTable& symbol_table)
      : ConvertOpToLLVMPattern<DispatchOp>(type_converter),
        symbol_table_(symbol_table) {}

  // Returns the ral dispatch function and inserts the declaration if not found.
  LLVMFuncOp getOrInsertDispatchFunction(PatternRewriter& rewriter,
                                         Operation* op) const;

  // Packs the inputs and outputs into a type-erased pointer array.
  // For example, `int func(int)` -> `void func(void* args[]) where args =
  // {in_ptr, out_ptr}`
  // On return, `resultPtrs` holds pointers to the output slots so the caller
  // can load the results after the dispatch call.
  Value rewriteInsOutsOfDispatchOp(DispatchOp dispatch_op,
                                   ArrayRef<Value> operands,
                                   ConversionPatternRewriter& rewriter,
                                   SmallVectorImpl<Value>& resultPtrs) const;

  LogicalResult matchAndRewrite(
      DispatchOp dispatch_op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const override;

 private:
  // Module-level symbol table shared across patterns so globals and function
  // declarations created during the conversion are de-duplicated.
  SymbolTable& symbol_table_;
};
208
209 // Returns the llvm function definition of ral dispatch op and creates it first
210 // if not exists.
getOrInsertDispatchFunction(PatternRewriter & rewriter,Operation * op) const211 LLVMFuncOp DispatchOpToLLVMPattern::getOrInsertDispatchFunction(
212 PatternRewriter& rewriter, Operation* op) const {
213 ModuleOp module = op->getParentOfType<ModuleOp>();
214 LLVMFuncOp func = symbol_table_.lookup<LLVMFuncOp>(kRalDispatchFunctionName);
215
216 if (func) return func;
217
218 // Try to insert the function since it's not found.
219 OpBuilder::InsertionGuard guard(rewriter);
220 OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
221 rewriter.setInsertionPointToStart(module.getBody());
222 Type llvm_pointer_type =
223 LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
224 Type llvm_pointer_pointer_type =
225 LLVM::LLVMPointerType::get(llvm_pointer_type);
226 func = rewriter.create<LLVMFuncOp>(
227 op->getLoc(), kRalDispatchFunctionName,
228 LLVM::LLVMFunctionType::get(
229 getVoidType(),
230 {
231 llvm_pointer_type, /* ral_context_t */
232 llvm_pointer_type, /* void* call_target_name */
233 llvm_pointer_pointer_type /* void** args */
234 },
235 /*isVarArg=*/false));
236
237 symbol_table_.insert(func);
238
239 rewriter.restoreInsertionPoint(ip);
240
241 return func;
242 }
243
244 // Packs the original inputs and outputs of the ral dispatch op to a uniform
245 // format.
246 //
247 // %struct = alloca(sizeof(struct { Parameters..., Results..., }))
248 // %array = alloca((NumParameters + NumResult) * sizeof(void *))
249 // for (i : [0, NumParameters))
250 // %fieldPtr = llvm.getelementptr %struct[0, i]
251 // llvm.store parameters[i], %fieldPtr
252 // %elementPtr = llvm.getelementptr %array[i]
253 // llvm.store %fieldPtr, %elementPtr
254 // for (i : [NumParameters, NumParameters + NumResult))
255 // %fieldPtr = llvm.getelementptr %struct[0, i]
256 // %elementPtr = llvm.getelementptr %array[i]
257 // llvm.store %fieldPtr, %elementPtr
258 // return %array
Value DispatchOpToLLVMPattern::rewriteInsOutsOfDispatchOp(
    DispatchOp dispatch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter& rewriter,
    SmallVectorImpl<Value>& resultPtrs) const {
  MLIRContext* ctx = rewriter.getContext();
  Location loc = dispatch_op.getLoc();

  Type llvm_pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
  Type llvm_pointer_pointer_type =
      LLVM::LLVMPointerType::get(llvm_pointer_type);
  Type llvm_int32_type = IntegerType::get(ctx, 32);

  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
                                                 rewriter.getI32IntegerAttr(0));
  Value one = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
                                                rewriter.getI32IntegerAttr(1));

  // Expand the converted operands (promoteOperands) and collect the llvm
  // types of all struct fields: first the promoted operands, then one slot
  // per op result.
  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
      loc, dispatch_op.getOperands(), operands, rewriter);
  SmallVector<Type, 4> argument_types;
  for (auto argument : arguments) argument_types.push_back(argument.getType());
  for (auto resultType : dispatch_op.getResultTypes())
    argument_types.push_back(getTypeConverter()->convertType(resultType));

  // One stack slot for the struct holding all params/results, and one array
  // of void* with one entry per struct field.
  auto struct_type =
      LLVM::LLVMStructType::getNewIdentified(ctx, StringRef(), argument_types);
  Value struct_ptr = rewriter.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
  Value array_size = rewriter.create<LLVM::ConstantOp>(
      loc, llvm_int32_type, rewriter.getI32IntegerAttr(argument_types.size()));
  Value array_ptr = rewriter.create<LLVM::AllocaOp>(
      loc, llvm_pointer_pointer_type, array_size, /*alignment=*/0);

  for (auto en : llvm::enumerate(argument_types)) {
    Value index = rewriter.create<LLVM::ConstantOp>(
        loc, llvm_int32_type, rewriter.getI32IntegerAttr(en.index()));
    Value field_ptr = rewriter.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(en.value()), struct_ptr,
        ArrayRef<Value>{zero, index});
    if (en.index() < arguments.size()) {
      // Input field: store the promoted operand into the struct.
      rewriter.create<LLVM::StoreOp>(loc, arguments[en.index()], field_ptr);
    } else {
      // Output field: left uninitialized for the callee to write; hand the
      // field pointer back so the caller can load the result afterwards.
      resultPtrs.push_back(field_ptr);
    }

    // Store the type-erased address of this field into the void* array.
    Value element_ptr = rewriter.create<LLVM::GEPOp>(
        loc, llvm_pointer_pointer_type, array_ptr, index);
    Value casted =
        rewriter.create<LLVM::BitcastOp>(loc, llvm_pointer_type, field_ptr);
    rewriter.create<LLVM::StoreOp>(loc, casted, element_ptr);
  }

  return array_ptr;
}
313
LogicalResult DispatchOpToLLVMPattern::matchAndRewrite(
    DispatchOp dispatch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter& rewriter) const {
  // Compute the mangled name encoding the target, device and full type
  // signature of this dispatch op.
  StrT target_name;
  if (failed(getDispatchOpSignatureEncoding(dispatch_op, target_name))) {
    dispatch_op->emitError("unknown types in the dispatch op");
    return failure();
  }

  // Make sure the trailing zero is included in the constant.
  target_name.push_back('\0');

  Operation* op = dispatch_op.getOperation();
  Location loc = op->getLoc();
  DispatchOp::Adaptor adaptor(operands);
  SmallVector<Value, 3> callOpOperands;
  LLVMFuncOp dispatch_func = getOrInsertDispatchFunction(rewriter, op);

  // Pack inputs and (uninitialized) output slots into a void* array;
  // resultPtrs receives pointers to the output slots.
  SmallVector<Value, 1> resultPtrs;
  Value packedArgs =
      rewriteInsOutsOfDispatchOp(dispatch_op, operands, rewriter, resultPtrs);

  // the first argument is ral_context
  callOpOperands.push_back(adaptor.ctx());
  // the second argument is the target name (the global's symbol name drops
  // the trailing '\0'; its content keeps it)
  callOpOperands.push_back(loadOrCreateGlobalString(
      rewriter, symbol_table_, op, target_name.str().drop_back(),
      target_name.str()));
  // the third argument is the args for target function
  callOpOperands.push_back(packedArgs);

  rewriter.create<LLVM::CallOp>(loc, llvm::None,
                                rewriter.getSymbolRefAttr(dispatch_func),
                                callOpOperands);

  // Load the results written by the callee and replace the op with them.
  SmallVector<Value, 1> results;
  llvm::transform(resultPtrs, std::back_inserter(results), [&](Value v) {
    return rewriter.create<LLVM::LoadOp>(loc, v);
  });

  rewriter.replaceOp(op, results);

  return success();
}
358
359 // A rewrite pattern to convert gpu.launch_func operations into corresponding
360 // runtime wrapper calls (modeled by ral.dispatch ops)
class ConvertLaunchFuncOpToRalCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter& type_converter,
                                      SymbolTable& symbol_table)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        symbol_table_(symbol_table) {}

 private:
  // Packs the kernel operands into a type-erased void* array on the stack.
  Value generateParamsArray(gpu::LaunchFuncOp launch_op,
                            ArrayRef<Value> operands, OpBuilder& builder) const;
  // NOTE(review): declared but not referenced anywhere in this file —
  // possibly dead; confirm before removing.
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder& builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const override;

  // Module-level symbol table used to de-duplicate created globals.
  SymbolTable& symbol_table_;
};
381
382 // Creates a struct containing all kernel parameters on the stack and returns
383 // an array of type-erased pointers to the fields of the struct. The array can
384 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
385 // The generated code is essentially as follows:
386 //
387 // %struct = alloca(sizeof(struct { Parameters... }))
388 // %array = alloca(NumParameters * sizeof(void *))
389 // for (i : [0, NumParameters))
390 // %fieldPtr = llvm.getelementptr %struct[0, i]
391 // llvm.store parameters[i], %fieldPtr
392 // %elementPtr = llvm.getelementptr %array[i]
393 // llvm.store %fieldPtr, %elementPtr
394 // return %array
Value ConvertLaunchFuncOpToRalCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    OpBuilder& builder) const {
  MLIRContext* ctx = builder.getContext();
  Type llvm_pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
  Type llvm_pointer_pointer_type =
      LLVM::LLVMPointerType::get(llvm_pointer_type);
  Type llvm_int32_type = IntegerType::get(ctx, 32);

  Location loc = launch_op.getLoc();
  // Only the trailing operands are kernel arguments (the leading ones are
  // grid/block sizes etc.); promote the memref ones to their llvm fields.
  int num_kernel_operands = launch_op.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launch_op.getOperands().take_back(num_kernel_operands),
      operands.take_back(num_kernel_operands), builder);
  int num_arguments = static_cast<int>(arguments.size());
  SmallVector<Type, 4> argument_types;
  argument_types.reserve(num_arguments);
  for (auto argument : arguments) argument_types.push_back(argument.getType());
  // One stack slot for the struct of all kernel params, and one void* array
  // with one entry per field.
  auto struct_type =
      LLVM::LLVMStructType::getNewIdentified(ctx, StringRef(), argument_types);
  Value one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type,
                                               builder.getI32IntegerAttr(1));
  Value struct_ptr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
  Value array_size = builder.create<LLVM::ConstantOp>(
      loc, llvm_int32_type, builder.getI32IntegerAttr(num_arguments));
  Value array_ptr = builder.create<LLVM::AllocaOp>(
      loc, llvm_pointer_pointer_type, array_size, /*alignment=*/0);
  Value zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type,
                                                builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    // Store the argument into its struct field, then store the type-erased
    // field address into the void* array.
    Value index = builder.create<LLVM::ConstantOp>(
        loc, llvm_int32_type, builder.getI32IntegerAttr(en.index()));
    Value field_ptr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
        ArrayRef<Value>{zero, index});
    builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
    Value element_ptr = builder.create<LLVM::GEPOp>(
        loc, llvm_pointer_pointer_type, array_ptr, index);
    Value casted =
        builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type, field_ptr);
    builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
  }
  return array_ptr;
}
440
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the `kGpuBinaryAttrName`
// (gpu.binary_blob) attribute.
LogicalResult ConvertLaunchFuncOpToRalCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter& rewriter) const {
  // Only the synchronous form is supported.
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  if (!kernel_module) {
    launch_op.emitOpError() << "cannot find corresponding kernel module.";
    return failure();
  }

  // The compiled device binary is attached to the gpu module as a string
  // attribute (gpu.binary_blob).
  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(kGpuBinaryAttrName);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << kGpuBinaryAttrName << " attribute";
    return failure();
  }

  Operation* op = launch_op.getOperation();
  Location loc = launch_op.getLoc();

  // Create a global for the module blob.
  StrT name_buffer(kernel_module.getName());
  name_buffer.append("_blob");

  Value module_blob = loadOrCreateGlobalString(
      rewriter, symbol_table_, op, name_buffer.str(), binary_attr.getValue());

  // Make sure the trailing zero is included in the constant.
  auto kernel_name = launch_op.getKernelName();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  Value kernel_name_global = loadOrCreateGlobalString(
      rewriter, symbol_table_, op, kernel_name_global_name,
      kernel_name_buffer.str());

  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());

  // The Ral Context is the first argument of the surrounding LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, operands, rewriter);

  Type llvm_int32_type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvm_int32_type,
                                                 rewriter.getI32IntegerAttr(0));
  // clang-format off
  // TODO(disc): we use the default stream a.t.m. Implement a stream assignment
  // algo in case necessary.
  SmallVector<Value, 12> newOperands{
      module_blob,        /* gpu module string */
      kernel_name_global, /* name of the kernel to launch */
      adaptor.gridSizeX(), adaptor.gridSizeY(), adaptor.gridSizeZ(),
      adaptor.blockSizeX(), adaptor.blockSizeY(), adaptor.blockSizeZ(),
      zero,          /* sharedMemBytes */
      zero,          /* gpu stream index */
      kernel_params  /* params for the kernel to launch */
  };
  // clang-format on

  // Replace the launch with a ral.dispatch to the kernel-launch runtime entry.
  // NOTE(review): the backend config passed here is "cpu" although this is a
  // GPU launch — presumably the dispatch itself runs host-side; confirm.
  rewriter.replaceOpWithNewOp<disc_ral::DispatchOp>(
      launch_op, llvm::None, context_arg, newOperands, kRalGpuLaunch, false,
      "cpu");
  return success();
}
522
// Pass that lowers the whole module (std, gpu launches, math, disc_ral ops)
// to the LLVM dialect and strips the now-unneeded gpu device modules.
class RalToLLVMPass : public RalToLLVMPassBase<RalToLLVMPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

 public:
  void runOnOperation() override {
    ModuleOp m = getOperation();
    SymbolTable symbol_table(m);

    // Populate type conversions.
    MLIRContext* ctx = m.getContext();
    LLVMTypeConverter type_converter(ctx);
    // The ral execution context lowers to an opaque i8*.
    type_converter.addConversion([&](RalExecutionContextType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });

    // Populate patterns.
    RewritePatternSet patterns(&getContext());
    populateStdExpandOpsPatterns(patterns);
    populateStdToLLVMConversionPatterns(type_converter, patterns);
    populateDiscRalToLLVMConversionPatterns(&type_converter, &symbol_table,
                                            &patterns);

    // Set target.
    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addIllegalDialect<StandardOpsDialect, gpu::GPUDialect,
                             disc_ral::RalDialect, math::MathDialect>();
    target.addIllegalOp<UnrealizedConversionCastOp>();
    // Mark modules as legal.
    target.addLegalOp<ModuleOp, gpu::GPUModuleOp>();
    // Do not look into gpu modules, only consider host-side.
    target.markOpRecursivelyLegal<gpu::GPUModuleOp>();

    if (failed(applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }

    // Finally, strip the GPU modules, as they are no longer needed.
    for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
      op.erase();
    }
  }
};
568
569 } // namespace
570
populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter * converter,SymbolTable * symbol_table,RewritePatternSet * patterns)571 void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter* converter,
572 SymbolTable* symbol_table,
573 RewritePatternSet* patterns) {
574 // clang-format off
575 patterns->insert<
576 ConvertLaunchFuncOpToRalCallPattern,
577 DispatchOpToLLVMPattern
578 >(*converter, *symbol_table);
579 // clang-format on
580 }
581
createRalToLLVMPass()582 std::unique_ptr<OperationPass<ModuleOp>> createRalToLLVMPass() {
583 return std::make_unique<RalToLLVMPass>();
584 }
585
586 } // namespace disc_ral
587 } // namespace mlir
588