• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 /// Helper function to extract the operand types that are passed to the
22 /// generated CallOp. MemRefTypes have their layout canonicalized since the
23 /// information is not used in signature generation.
24 /// Note that static size information is not modified.
extractOperandTypes(Operation * op)25 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
26   SmallVector<Type, 4> result;
27   result.reserve(op->getNumOperands());
28   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
29     auto *ctx = op->getContext();
30     auto numLoops = indexedGenericOp.getNumLoops();
31     result.reserve(op->getNumOperands() + numLoops);
32     result.assign(numLoops, IndexType::get(ctx));
33   }
34   for (auto type : op->getOperandTypes()) {
35     // The underlying descriptor type (e.g. LLVM) does not have layout
36     // information. Canonicalizing the type at the level of std when going into
37     // a library call avoids needing to introduce DialectCastOp.
38     if (auto memrefType = type.dyn_cast<MemRefType>())
39       result.push_back(eraseStridedLayout(memrefType));
40     else
41       result.push_back(type);
42   }
43   return result;
44 }
45 
46 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
47 // If the library function does not exist, insert a declaration.
getLibraryCallSymbolRef(Operation * op,PatternRewriter & rewriter)48 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
49                                                  PatternRewriter &rewriter) {
50   auto linalgOp = cast<LinalgOp>(op);
51   auto fnName = linalgOp.getLibraryCallName();
52   if (fnName.empty()) {
53     op->emitWarning("No library call defined for: ") << *op;
54     return {};
55   }
56 
57   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
58   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
59   auto module = op->getParentOfType<ModuleOp>();
60   if (module.lookupSymbol(fnName)) {
61     return fnNameAttr;
62   }
63 
64   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
65   assert(op->getNumResults() == 0 &&
66          "Library call for linalg operation can be generated only for ops that "
67          "have void return types");
68   auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
69 
70   OpBuilder::InsertionGuard guard(rewriter);
71   // Insert before module terminator.
72   rewriter.setInsertionPoint(module.getBody(),
73                              std::prev(module.getBody()->end()));
74   FuncOp funcOp =
75       rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
76   // Insert a function attribute that will trigger the emission of the
77   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
78   // a normalized ABI. This interface is added during std to llvm conversion.
79   funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
80   funcOp.setPrivate();
81   return fnNameAttr;
82 }
83 
84 static SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder & b,Location loc,ValueRange operands)85 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
86                                       ValueRange operands) {
87   SmallVector<Value, 4> res;
88   res.reserve(operands.size());
89   for (auto op : operands) {
90     auto memrefType = op.getType().dyn_cast<MemRefType>();
91     if (!memrefType) {
92       res.push_back(op);
93       continue;
94     }
95     Value cast =
96         b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
97     res.push_back(cast);
98   }
99   return res;
100 }
101 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const102 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
103     Operation *op, PatternRewriter &rewriter) const {
104   // Only LinalgOp for which there is no specialized pattern go through this.
105   if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
106     return failure();
107 
108   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
109   if (!libraryCallName)
110     return failure();
111 
112   rewriter.replaceOpWithNewOp<mlir::CallOp>(
113       op, libraryCallName.getValue(), TypeRange(),
114       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
115                                             op->getOperands()));
116   return success();
117 }
118 
matchAndRewrite(CopyOp op,PatternRewriter & rewriter) const119 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
120     CopyOp op, PatternRewriter &rewriter) const {
121   auto inputPerm = op.inputPermutation();
122   if (inputPerm.hasValue() && !inputPerm->isIdentity())
123     return failure();
124   auto outputPerm = op.outputPermutation();
125   if (outputPerm.hasValue() && !outputPerm->isIdentity())
126     return failure();
127 
128   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
129   if (!libraryCallName)
130     return failure();
131 
132   rewriter.replaceOpWithNewOp<mlir::CallOp>(
133       op, libraryCallName.getValue(), TypeRange(),
134       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
135                                             op.getOperands()));
136   return success();
137 }
138 
matchAndRewrite(CopyOp op,PatternRewriter & rewriter) const139 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
140     CopyOp op, PatternRewriter &rewriter) const {
141   Value in = op.input(), out = op.output();
142 
143   // If either inputPerm or outputPerm are non-identities, insert transposes.
144   auto inputPerm = op.inputPermutation();
145   if (inputPerm.hasValue() && !inputPerm->isIdentity())
146     in = rewriter.create<TransposeOp>(op.getLoc(), in,
147                                       AffineMapAttr::get(*inputPerm));
148   auto outputPerm = op.outputPermutation();
149   if (outputPerm.hasValue() && !outputPerm->isIdentity())
150     out = rewriter.create<TransposeOp>(op.getLoc(), out,
151                                        AffineMapAttr::get(*outputPerm));
152 
153   // If nothing was transposed, fail and let the conversion kick in.
154   if (in == op.input() && out == op.output())
155     return failure();
156 
157   rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
158   return success();
159 }
160 
161 LogicalResult
matchAndRewrite(IndexedGenericOp op,PatternRewriter & rewriter) const162 mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
163     IndexedGenericOp op, PatternRewriter &rewriter) const {
164   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
165   if (!libraryCallName)
166     return failure();
167 
168   // TODO: Use induction variables values instead of zeros, when
169   // IndexedGenericOp is tiled.
170   auto zero = rewriter.create<mlir::ConstantOp>(
171       op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
172   auto indexedGenericOp = cast<IndexedGenericOp>(op);
173   auto numLoops = indexedGenericOp.getNumLoops();
174   SmallVector<Value, 4> operands;
175   operands.reserve(numLoops + op.getNumOperands());
176   for (unsigned i = 0; i < numLoops; ++i)
177     operands.push_back(zero);
178   for (auto operand : op.getOperands())
179     operands.push_back(operand);
180   rewriter.replaceOpWithNewOp<mlir::CallOp>(
181       op, libraryCallName.getValue(), TypeRange(),
182       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
183   return success();
184 }
185 
186 /// Populate the given list with patterns that convert from Linalg to Standard.
populateLinalgToStandardConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)187 void mlir::linalg::populateLinalgToStandardConversionPatterns(
188     OwningRewritePatternList &patterns, MLIRContext *ctx) {
189   // TODO: ConvOp conversion needs to export a descriptor with relevant
190   // attribute values such as kernel striding and dilation.
191   // clang-format off
192   patterns.insert<
193       CopyOpToLibraryCallRewrite,
194       CopyTransposeRewrite,
195       IndexedGenericOpToLibraryCallRewrite>(ctx);
196   patterns.insert<LinalgOpToLibraryCallRewrite>();
197   // clang-format on
198 }
199 
200 namespace {
201 struct ConvertLinalgToStandardPass
202     : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
203   void runOnOperation() override;
204 };
205 } // namespace
206 
runOnOperation()207 void ConvertLinalgToStandardPass::runOnOperation() {
208   auto module = getOperation();
209   ConversionTarget target(getContext());
210   target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
211   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
212   target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
213   OwningRewritePatternList patterns;
214   populateLinalgToStandardConversionPatterns(patterns, &getContext());
215   if (failed(applyFullConversion(module, target, std::move(patterns))))
216     signalPassFailure();
217 }
218 
219 std::unique_ptr<OperationPass<ModuleOp>>
createConvertLinalgToStandardPass()220 mlir::createConvertLinalgToStandardPass() {
221   return std::make_unique<ConvertLinalgToStandardPass>();
222 }
223