1 //===- Generalization.cpp - linalg named ops to generic ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26
27 #define DEBUG_TYPE "linalg-generalization"
28
29 using namespace mlir;
30
31 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
32 // the given `namedOp` does not have a region builder.
createGenericOpFromNamedOp(linalg::LinalgOp namedOp,OpBuilder & builder)33 static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
34 OpBuilder &builder) {
35 auto regionBuilder = namedOp.getRegionBuilder();
36 if (!regionBuilder) {
37 LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
38 return nullptr;
39 }
40
41 SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
42 auto iterators = llvm::to_vector<4>(
43 namedOp.iterator_types().getAsValueRange<StringAttr>());
44 auto resultTypes = namedOp.getOutputTensorTypes();
45 SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
46
47 return builder.create<linalg::GenericOp>(
48 namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(),
49 namedOp.getInitTensors(), indexingMaps, iterators,
50 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
51 edsc::ScopedContext scope(bodyBuilder, loc);
52 regionBuilder(*bodyBuilder.getBlock());
53 });
54 }
55
56 namespace {
57
58 /// Base class for all linalg generalization patterns. A subclass must provide
59 /// the following method:
60 /// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
61 /// for creating the generic op.
62 // TODO: remove this pattern after migrating all manually-written named ops
63 // into auto-generated ones.
64 template <typename ConcretePattern, typename RootOp>
65 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
LinalgGeneralizationPattern__anonec4b71050211::LinalgGeneralizationPattern66 LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
67 PatternBenefit benefit = 1)
68 : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
69
matchAndRewrite__anonec4b71050211::LinalgGeneralizationPattern70 LogicalResult matchAndRewrite(RootOp rootOp,
71 PatternRewriter &rewriter) const override {
72 auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
73 if (!linalgOp)
74 return failure();
75 if (failed(marker.checkAndNotify(rewriter, linalgOp)))
76 return failure();
77
78 auto *pattern = static_cast<const ConcretePattern *>(this);
79 linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
80 if (!genericOp)
81 return failure();
82
83 rewriter.replaceOp(rootOp, genericOp.getResults());
84 marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
85 return success();
86 }
87
88 private:
89 linalg::LinalgMarker marker;
90 };
91
92 struct GeneralizeConvOp
93 : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
94 using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
95
96 linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
97 };
98
99 /// Catch-all pattern for converting all named ops with a region builder into
100 /// linalg.generic.
101 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern__anonec4b71050211::LinalgNamedOpGeneralizationPattern102 LinalgNamedOpGeneralizationPattern(MLIRContext *context,
103 linalg::LinalgMarker marker,
104 PatternBenefit benefit = 1)
105 : RewritePattern(benefit, MatchAnyOpTypeTag()),
106 marker(std::move(marker)) {}
107
matchAndRewrite__anonec4b71050211::LinalgNamedOpGeneralizationPattern108 LogicalResult matchAndRewrite(Operation *rootOp,
109 PatternRewriter &rewriter) const override {
110 auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
111 if (!linalgOp)
112 return failure();
113 if (failed(marker.checkAndNotify(rewriter, linalgOp)))
114 return failure();
115
116 // No nothing to do for linalg.generic and linalg.indexed_generic.
117 if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
118 return failure();
119
120 linalg::GenericOp genericOp =
121 createGenericOpFromNamedOp(linalgOp, rewriter);
122 if (!genericOp)
123 return failure();
124
125 rewriter.replaceOp(rootOp, genericOp.getResults());
126 marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
127 return success();
128 }
129
130 private:
131 linalg::LinalgMarker marker;
132 };
133
134 struct LinalgGeneralizationPass
135 : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
136 void runOnFunction() override;
137 };
138
139 } // namespace
140
runOnFunction()141 void LinalgGeneralizationPass::runOnFunction() {
142 FuncOp func = getFunction();
143 OwningRewritePatternList patterns;
144 linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
145 linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
146 applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
147 }
148
createGenericOp(linalg::ConvOp convOp,OpBuilder & builder) const149 linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
150 OpBuilder &builder) const {
151 SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
152 auto iterators =
153 llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
154 return builder.create<linalg::GenericOp>(
155 convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
156 convOp.getInputBuffers(), convOp.getOutputBuffers(),
157 /*initTensors=*/ValueRange(), indexingMaps, iterators,
158 [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
159 Value mul =
160 bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
161 Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
162 bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
163 });
164 }
165
populateLinalgConvGeneralizationPatterns(MLIRContext * context,OwningRewritePatternList & patterns,linalg::LinalgMarker marker)166 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
167 MLIRContext *context, OwningRewritePatternList &patterns,
168 linalg::LinalgMarker marker) {
169 patterns.insert<GeneralizeConvOp>(context, marker);
170 }
171
populateLinalgNamedOpsGeneralizationPatterns(MLIRContext * context,OwningRewritePatternList & patterns,linalg::LinalgMarker marker)172 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
173 MLIRContext *context, OwningRewritePatternList &patterns,
174 linalg::LinalgMarker marker) {
175 patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
176 }
177
createLinalgGeneralizationPass()178 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
179 return std::make_unique<LinalgGeneralizationPass>();
180 }
181