1 //===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes operations before the conversion to SPIR-V
10 // dialect to handle ops that cannot be lowered directly.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
16 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23 using namespace mlir;
24
25 namespace {
26 /// Merges subview operation with load/transferRead operation.
27 template <typename OpTy>
28 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
29 public:
30 using OpRewritePattern<OpTy>::OpRewritePattern;
31
32 LogicalResult matchAndRewrite(OpTy loadOp,
33 PatternRewriter &rewriter) const override;
34
35 private:
36 void replaceOp(OpTy loadOp, SubViewOp subViewOp,
37 ArrayRef<Value> sourceIndices,
38 PatternRewriter &rewriter) const;
39 };
40
41 /// Merges subview operation with store/transferWriteOp operation.
42 template <typename OpTy>
43 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
44 public:
45 using OpRewritePattern<OpTy>::OpRewritePattern;
46
47 LogicalResult matchAndRewrite(OpTy storeOp,
48 PatternRewriter &rewriter) const override;
49
50 private:
51 void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
52 ArrayRef<Value> sourceIndices,
53 PatternRewriter &rewriter) const;
54 };
55
56 template <>
replaceOp(LoadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const57 void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
58 SubViewOp subViewOp,
59 ArrayRef<Value> sourceIndices,
60 PatternRewriter &rewriter) const {
61 rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
62 sourceIndices);
63 }
64
65 template <>
replaceOp(vector::TransferReadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const66 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
67 vector::TransferReadOp loadOp, SubViewOp subViewOp,
68 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
69 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
70 loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
71 loadOp.permutation_map(), loadOp.padding(), loadOp.maskedAttr());
72 }
73
74 template <>
replaceOp(StoreOp storeOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const75 void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
76 StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
77 PatternRewriter &rewriter) const {
78 rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
79 subViewOp.source(), sourceIndices);
80 }
81
82 template <>
replaceOp(vector::TransferWriteOp tranferWriteOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const83 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
84 vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
85 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
86 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
87 tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
88 sourceIndices, tranferWriteOp.permutation_map(),
89 tranferWriteOp.maskedAttr());
90 }
91 } // namespace
92
93 //===----------------------------------------------------------------------===//
94 // Utility functions for op legalization.
95 //===----------------------------------------------------------------------===//
96
97 /// Given the 'indices' of an load/store operation where the memref is a result
98 /// of a subview op, returns the indices w.r.t to the source memref of the
99 /// subview op. For example
100 ///
101 /// %0 = ... : memref<12x42xf32>
102 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
103 /// memref<4x4xf32, offset=?, strides=[?, ?]>
104 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
105 ///
106 /// could be folded into
107 ///
108 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
109 /// memref<12x42xf32>
110 static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)111 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
112 SubViewOp subViewOp, ValueRange indices,
113 SmallVectorImpl<Value> &sourceIndices) {
114 // TODO: Aborting when the offsets are static. There might be a way to fold
115 // the subview op with load even if the offsets have been canonicalized
116 // away.
117 SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
118 auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
119 auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
120 assert(opRanges.size() == indices.size() &&
121 "expected as many indices as rank of subview op result type");
122
123 // New indices for the load are the current indices * subview_stride +
124 // subview_offset.
125 sourceIndices.resize(indices.size());
126 for (auto index : llvm::enumerate(indices)) {
127 auto offset = *(opOffsets.begin() + index.index());
128 auto stride = *(opStrides.begin() + index.index());
129 auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
130 sourceIndices[index.index()] =
131 rewriter.create<AddIOp>(loc, offset, mul).getResult();
132 }
133 return success();
134 }
135
136 //===----------------------------------------------------------------------===//
137 // Folding SubViewOp and LoadOp/TransferReadOp.
138 //===----------------------------------------------------------------------===//
139
140 template <typename OpTy>
141 LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const142 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
143 PatternRewriter &rewriter) const {
144 auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
145 if (!subViewOp) {
146 return failure();
147 }
148 SmallVector<Value, 4> sourceIndices;
149 if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
150 loadOp.indices(), sourceIndices)))
151 return failure();
152
153 replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
154 return success();
155 }
156
157 //===----------------------------------------------------------------------===//
158 // Folding SubViewOp and StoreOp/TransferWriteOp.
159 //===----------------------------------------------------------------------===//
160
161 template <typename OpTy>
162 LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const163 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
164 PatternRewriter &rewriter) const {
165 auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
166 if (!subViewOp) {
167 return failure();
168 }
169 SmallVector<Value, 4> sourceIndices;
170 if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
171 storeOp.indices(), sourceIndices)))
172 return failure();
173
174 replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
175 return success();
176 }
177
178 //===----------------------------------------------------------------------===//
179 // Hook for adding patterns.
180 //===----------------------------------------------------------------------===//
181
populateStdLegalizationPatternsForSPIRVLowering(MLIRContext * context,OwningRewritePatternList & patterns)182 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
183 MLIRContext *context, OwningRewritePatternList &patterns) {
184 patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
185 LoadOpOfSubViewFolder<vector::TransferReadOp>,
186 StoreOpOfSubViewFolder<StoreOp>,
187 StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
188 }
189
190 //===----------------------------------------------------------------------===//
191 // Pass for testing just the legalization patterns.
192 //===----------------------------------------------------------------------===//
193
194 namespace {
195 struct SPIRVLegalization final
196 : public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
197 void runOnOperation() override;
198 };
199 } // namespace
200
runOnOperation()201 void SPIRVLegalization::runOnOperation() {
202 OwningRewritePatternList patterns;
203 auto *context = &getContext();
204 populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
205 applyPatternsAndFoldGreedily(getOperation()->getRegions(),
206 std::move(patterns));
207 }
208
createLegalizeStdOpsForSPIRVLoweringPass()209 std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
210 return std::make_unique<SPIRVLegalization>();
211 }
212