1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17
18 using namespace mlir;
19 using namespace mlir::linalg;
20
21 /// Helper function to extract the operand types that are passed to the
22 /// generated CallOp. MemRefTypes have their layout canonicalized since the
23 /// information is not used in signature generation.
24 /// Note that static size information is not modified.
extractOperandTypes(Operation * op)25 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
26 SmallVector<Type, 4> result;
27 result.reserve(op->getNumOperands());
28 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
29 auto *ctx = op->getContext();
30 auto numLoops = indexedGenericOp.getNumLoops();
31 result.reserve(op->getNumOperands() + numLoops);
32 result.assign(numLoops, IndexType::get(ctx));
33 }
34 for (auto type : op->getOperandTypes()) {
35 // The underlying descriptor type (e.g. LLVM) does not have layout
36 // information. Canonicalizing the type at the level of std when going into
37 // a library call avoids needing to introduce DialectCastOp.
38 if (auto memrefType = type.dyn_cast<MemRefType>())
39 result.push_back(eraseStridedLayout(memrefType));
40 else
41 result.push_back(type);
42 }
43 return result;
44 }
45
46 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
47 // If the library function does not exist, insert a declaration.
getLibraryCallSymbolRef(Operation * op,PatternRewriter & rewriter)48 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
49 PatternRewriter &rewriter) {
50 auto linalgOp = cast<LinalgOp>(op);
51 auto fnName = linalgOp.getLibraryCallName();
52 if (fnName.empty()) {
53 op->emitWarning("No library call defined for: ") << *op;
54 return {};
55 }
56
57 // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
58 FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
59 auto module = op->getParentOfType<ModuleOp>();
60 if (module.lookupSymbol(fnName)) {
61 return fnNameAttr;
62 }
63
64 SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
65 assert(op->getNumResults() == 0 &&
66 "Library call for linalg operation can be generated only for ops that "
67 "have void return types");
68 auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
69
70 OpBuilder::InsertionGuard guard(rewriter);
71 // Insert before module terminator.
72 rewriter.setInsertionPoint(module.getBody(),
73 std::prev(module.getBody()->end()));
74 FuncOp funcOp =
75 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
76 // Insert a function attribute that will trigger the emission of the
77 // corresponding `_mlir_ciface_xxx` interface so that external libraries see
78 // a normalized ABI. This interface is added during std to llvm conversion.
79 funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
80 funcOp.setPrivate();
81 return fnNameAttr;
82 }
83
84 static SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder & b,Location loc,ValueRange operands)85 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
86 ValueRange operands) {
87 SmallVector<Value, 4> res;
88 res.reserve(operands.size());
89 for (auto op : operands) {
90 auto memrefType = op.getType().dyn_cast<MemRefType>();
91 if (!memrefType) {
92 res.push_back(op);
93 continue;
94 }
95 Value cast =
96 b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
97 res.push_back(cast);
98 }
99 return res;
100 }
101
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const102 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
103 Operation *op, PatternRewriter &rewriter) const {
104 // Only LinalgOp for which there is no specialized pattern go through this.
105 if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
106 return failure();
107
108 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
109 if (!libraryCallName)
110 return failure();
111
112 rewriter.replaceOpWithNewOp<mlir::CallOp>(
113 op, libraryCallName.getValue(), TypeRange(),
114 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
115 op->getOperands()));
116 return success();
117 }
118
matchAndRewrite(CopyOp op,PatternRewriter & rewriter) const119 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
120 CopyOp op, PatternRewriter &rewriter) const {
121 auto inputPerm = op.inputPermutation();
122 if (inputPerm.hasValue() && !inputPerm->isIdentity())
123 return failure();
124 auto outputPerm = op.outputPermutation();
125 if (outputPerm.hasValue() && !outputPerm->isIdentity())
126 return failure();
127
128 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
129 if (!libraryCallName)
130 return failure();
131
132 rewriter.replaceOpWithNewOp<mlir::CallOp>(
133 op, libraryCallName.getValue(), TypeRange(),
134 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
135 op.getOperands()));
136 return success();
137 }
138
matchAndRewrite(CopyOp op,PatternRewriter & rewriter) const139 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
140 CopyOp op, PatternRewriter &rewriter) const {
141 Value in = op.input(), out = op.output();
142
143 // If either inputPerm or outputPerm are non-identities, insert transposes.
144 auto inputPerm = op.inputPermutation();
145 if (inputPerm.hasValue() && !inputPerm->isIdentity())
146 in = rewriter.create<TransposeOp>(op.getLoc(), in,
147 AffineMapAttr::get(*inputPerm));
148 auto outputPerm = op.outputPermutation();
149 if (outputPerm.hasValue() && !outputPerm->isIdentity())
150 out = rewriter.create<TransposeOp>(op.getLoc(), out,
151 AffineMapAttr::get(*outputPerm));
152
153 // If nothing was transposed, fail and let the conversion kick in.
154 if (in == op.input() && out == op.output())
155 return failure();
156
157 rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
158 return success();
159 }
160
161 LogicalResult
matchAndRewrite(IndexedGenericOp op,PatternRewriter & rewriter) const162 mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
163 IndexedGenericOp op, PatternRewriter &rewriter) const {
164 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
165 if (!libraryCallName)
166 return failure();
167
168 // TODO: Use induction variables values instead of zeros, when
169 // IndexedGenericOp is tiled.
170 auto zero = rewriter.create<mlir::ConstantOp>(
171 op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
172 auto indexedGenericOp = cast<IndexedGenericOp>(op);
173 auto numLoops = indexedGenericOp.getNumLoops();
174 SmallVector<Value, 4> operands;
175 operands.reserve(numLoops + op.getNumOperands());
176 for (unsigned i = 0; i < numLoops; ++i)
177 operands.push_back(zero);
178 for (auto operand : op.getOperands())
179 operands.push_back(operand);
180 rewriter.replaceOpWithNewOp<mlir::CallOp>(
181 op, libraryCallName.getValue(), TypeRange(),
182 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
183 return success();
184 }
185
186 /// Populate the given list with patterns that convert from Linalg to Standard.
populateLinalgToStandardConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)187 void mlir::linalg::populateLinalgToStandardConversionPatterns(
188 OwningRewritePatternList &patterns, MLIRContext *ctx) {
189 // TODO: ConvOp conversion needs to export a descriptor with relevant
190 // attribute values such as kernel striding and dilation.
191 // clang-format off
192 patterns.insert<
193 CopyOpToLibraryCallRewrite,
194 CopyTransposeRewrite,
195 IndexedGenericOpToLibraryCallRewrite>(ctx);
196 patterns.insert<LinalgOpToLibraryCallRewrite>();
197 // clang-format on
198 }
199
200 namespace {
201 struct ConvertLinalgToStandardPass
202 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
203 void runOnOperation() override;
204 };
205 } // namespace
206
runOnOperation()207 void ConvertLinalgToStandardPass::runOnOperation() {
208 auto module = getOperation();
209 ConversionTarget target(getContext());
210 target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
211 target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
212 target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
213 OwningRewritePatternList patterns;
214 populateLinalgToStandardConversionPatterns(patterns, &getContext());
215 if (failed(applyFullConversion(module, target, std::move(patterns))))
216 signalPassFailure();
217 }
218
219 std::unique_ptr<OperationPass<ModuleOp>>
createConvertLinalgToStandardPass()220 mlir::createConvertLinalgToStandardPass() {
221 return std::make_unique<ConvertLinalgToStandardPass>();
222 }
223