1 //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

#include <utility>
17
18 using namespace mlir;
19
allocator()20 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
21 thread_local llvm::BumpPtrAllocator *allocator = nullptr;
22 return allocator;
23 }
24
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)25 NestedMatch NestedMatch::build(Operation *operation,
26 ArrayRef<NestedMatch> nestedMatches) {
27 auto *result = allocator()->Allocate<NestedMatch>();
28 auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
29 std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
30 new (result) NestedMatch();
31 result->matchedOperation = operation;
32 result->matchedChildren =
33 ArrayRef<NestedMatch>(children, nestedMatches.size());
34 return *result;
35 }
36
allocator()37 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
38 thread_local llvm::BumpPtrAllocator *allocator = nullptr;
39 return allocator;
40 }
41
NestedPattern(ArrayRef<NestedPattern> nested,FilterFunctionType filter)42 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
43 FilterFunctionType filter)
44 : nestedPatterns(), filter(filter), skip(nullptr) {
45 if (!nested.empty()) {
46 auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
47 std::uninitialized_copy(nested.begin(), nested.end(), newNested);
48 nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
49 }
50 }
51
getDepth() const52 unsigned NestedPattern::getDepth() const {
53 if (nestedPatterns.empty()) {
54 return 1;
55 }
56 unsigned depth = 0;
57 for (auto &c : nestedPatterns) {
58 depth = std::max(depth, c.getDepth());
59 }
60 return depth + 1;
61 }
62
63 /// Matches a single operation in the following way:
64 /// 1. checks the kind of operation against the matcher, if different then
65 /// there is no match;
66 /// 2. calls the customizable filter function to refine the single operation
67 /// match with extra semantic constraints;
68 /// 3. if all is good, recursively matches the nested patterns;
69 /// 4. if all nested match then the single operation matches too and is
70 /// appended to the list of matches;
71 /// 5. TODO: Optionally applies actions (lambda), in which case we will want
72 /// to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)73 void NestedPattern::matchOne(Operation *op,
74 SmallVectorImpl<NestedMatch> *matches) {
75 if (skip == op) {
76 return;
77 }
78 // Local custom filter function
79 if (!filter(*op)) {
80 return;
81 }
82
83 if (nestedPatterns.empty()) {
84 SmallVector<NestedMatch, 8> nestedMatches;
85 matches->push_back(NestedMatch::build(op, nestedMatches));
86 return;
87 }
88 // Take a copy of each nested pattern so we can match it.
89 for (auto nestedPattern : nestedPatterns) {
90 SmallVector<NestedMatch, 8> nestedMatches;
91 // Skip elem in the walk immediately following. Without this we would
92 // essentially need to reimplement walk here.
93 nestedPattern.skip = op;
94 nestedPattern.match(op, &nestedMatches);
95 // If we could not match even one of the specified nestedPattern, early exit
96 // as this whole branch is not a match.
97 if (nestedMatches.empty()) {
98 return;
99 }
100 matches->push_back(NestedMatch::build(op, nestedMatches));
101 }
102 }
103
isAffineForOp(Operation & op)104 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
105
isAffineIfOp(Operation & op)106 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
107
108 namespace mlir {
109 namespace matcher {
110
Op(FilterFunctionType filter)111 NestedPattern Op(FilterFunctionType filter) {
112 return NestedPattern({}, filter);
113 }
114
If(NestedPattern child)115 NestedPattern If(NestedPattern child) {
116 return NestedPattern(child, isAffineIfOp);
117 }
If(FilterFunctionType filter,NestedPattern child)118 NestedPattern If(FilterFunctionType filter, NestedPattern child) {
119 return NestedPattern(child, [filter](Operation &op) {
120 return isAffineIfOp(op) && filter(op);
121 });
122 }
If(ArrayRef<NestedPattern> nested)123 NestedPattern If(ArrayRef<NestedPattern> nested) {
124 return NestedPattern(nested, isAffineIfOp);
125 }
If(FilterFunctionType filter,ArrayRef<NestedPattern> nested)126 NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
127 return NestedPattern(nested, [filter](Operation &op) {
128 return isAffineIfOp(op) && filter(op);
129 });
130 }
131
For(NestedPattern child)132 NestedPattern For(NestedPattern child) {
133 return NestedPattern(child, isAffineForOp);
134 }
For(FilterFunctionType filter,NestedPattern child)135 NestedPattern For(FilterFunctionType filter, NestedPattern child) {
136 return NestedPattern(
137 child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
138 }
For(ArrayRef<NestedPattern> nested)139 NestedPattern For(ArrayRef<NestedPattern> nested) {
140 return NestedPattern(nested, isAffineForOp);
141 }
For(FilterFunctionType filter,ArrayRef<NestedPattern> nested)142 NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
143 return NestedPattern(
144 nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
145 }
146
isLoadOrStore(Operation & op)147 bool isLoadOrStore(Operation &op) {
148 return isa<AffineLoadOp, AffineStoreOp>(op);
149 }
150
151 } // end namespace matcher
152 } // end namespace mlir
153