1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15
16 using namespace mlir;
17
18 namespace {
19 /// A pattern that converts the region arguments in a single-region OpenMP
20 /// operation to the LLVM dialect. The body of the region is not modified and is
21 /// expected to either be processed by the conversion infrastructure or already
22 /// contain ops compatible with LLVM dialect types.
23 template <typename OpType>
24 struct RegionOpConversion : public ConvertToLLVMPattern {
RegionOpConversion__anon0b64e0de0111::RegionOpConversion25 explicit RegionOpConversion(MLIRContext *context,
26 LLVMTypeConverter &typeConverter)
27 : ConvertToLLVMPattern(OpType::getOperationName(), context,
28 typeConverter) {}
29
30 LogicalResult
matchAndRewrite__anon0b64e0de0111::RegionOpConversion31 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
32 ConversionPatternRewriter &rewriter) const override {
33 auto curOp = cast<OpType>(op);
34 auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
35 curOp.getAttrs());
36 rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
37 newOp.region().end());
38 if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter)))
39 return failure();
40
41 rewriter.eraseOp(op);
42 return success();
43 }
44 };
45 } // namespace
46
populateOpenMPToLLVMConversionPatterns(MLIRContext * context,LLVMTypeConverter & converter,OwningRewritePatternList & patterns)47 void mlir::populateOpenMPToLLVMConversionPatterns(
48 MLIRContext *context, LLVMTypeConverter &converter,
49 OwningRewritePatternList &patterns) {
50 patterns.insert<RegionOpConversion<omp::ParallelOp>,
51 RegionOpConversion<omp::WsLoopOp>>(context, converter);
52 }
53
54 namespace {
55 struct ConvertOpenMPToLLVMPass
56 : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
57 void runOnOperation() override;
58 };
59 } // namespace
60
runOnOperation()61 void ConvertOpenMPToLLVMPass::runOnOperation() {
62 auto module = getOperation();
63 MLIRContext *context = &getContext();
64
65 // Convert to OpenMP operations with LLVM IR dialect
66 OwningRewritePatternList patterns;
67 LLVMTypeConverter converter(&getContext());
68 populateStdToLLVMConversionPatterns(converter, patterns);
69 populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
70
71 LLVMConversionTarget target(getContext());
72 target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
73 [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
74 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
75 omp::BarrierOp, omp::TaskwaitOp>();
76 if (failed(applyPartialConversion(module, target, std::move(patterns))))
77 signalPassFailure();
78 }
79
createConvertOpenMPToLLVMPass()80 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
81 return std::make_unique<ConvertOpenMPToLLVMPass>();
82 }
83