1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/Pass.h"
10 #include "mlir/Pass/PassManager.h"
11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12
13 using namespace mlir;
14
15 /// Custom constraint invoked from PDL.
customSingleEntityConstraint(PDLValue value,ArrayAttr constantParams,PatternRewriter & rewriter)16 static LogicalResult customSingleEntityConstraint(PDLValue value,
17 ArrayAttr constantParams,
18 PatternRewriter &rewriter) {
19 Operation *rootOp = value.cast<Operation *>();
20 return success(rootOp->getName().getStringRef() == "test.op");
21 }
customMultiEntityConstraint(ArrayRef<PDLValue> values,ArrayAttr constantParams,PatternRewriter & rewriter)22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
23 ArrayAttr constantParams,
24 PatternRewriter &rewriter) {
25 return customSingleEntityConstraint(values[1], constantParams, rewriter);
26 }
27
28 // Custom creator invoked from PDL.
customCreate(ArrayRef<PDLValue> args,ArrayAttr constantParams,PatternRewriter & rewriter)29 static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
30 PatternRewriter &rewriter) {
31 return rewriter.createOperation(
32 OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"));
33 }
34
35 /// Custom rewriter invoked from PDL.
customRewriter(Operation * root,ArrayRef<PDLValue> args,ArrayAttr constantParams,PatternRewriter & rewriter)36 static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
37 ArrayAttr constantParams,
38 PatternRewriter &rewriter) {
39 OperationState successOpState(root->getLoc(), "test.success");
40 successOpState.addOperands(args[0].cast<Value>());
41 successOpState.addAttribute("constantParams", constantParams);
42 rewriter.createOperation(successOpState);
43 rewriter.eraseOp(root);
44 }
45
46 namespace {
47 struct TestPDLByteCodePass
48 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
runOnOperation__anon58e0aa0f0111::TestPDLByteCodePass49 void runOnOperation() final {
50 ModuleOp module = getOperation();
51
52 // The test cases are encompassed via two modules, one containing the
53 // patterns and one containing the operations to rewrite.
54 ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns");
55 ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir");
56 if (!patternModule || !irModule)
57 return;
58
59 // Process the pattern module.
60 patternModule.getOperation()->remove();
61 PDLPatternModule pdlPattern(patternModule);
62 pdlPattern.registerConstraintFunction("multi_entity_constraint",
63 customMultiEntityConstraint);
64 pdlPattern.registerConstraintFunction("single_entity_constraint",
65 customSingleEntityConstraint);
66 pdlPattern.registerCreateFunction("creator", customCreate);
67 pdlPattern.registerRewriteFunction("rewriter", customRewriter);
68
69 OwningRewritePatternList patternList(std::move(pdlPattern));
70
71 // Invoke the pattern driver with the provided patterns.
72 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
73 std::move(patternList));
74 }
75 };
76 } // end anonymous namespace
77
78 namespace mlir {
79 namespace test {
registerTestPDLByteCodePass()80 void registerTestPDLByteCodePass() {
81 PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
82 "Test PDL ByteCode functionality");
83 }
84 } // namespace test
85 } // namespace mlir
86