• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes operations before the conversion to SPIR-V
10 // dialect to handle ops that cannot be lowered directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
16 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 /// Merges subview operation with load/transferRead operation.
27 template <typename OpTy>
28 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
29 public:
30   using OpRewritePattern<OpTy>::OpRewritePattern;
31 
32   LogicalResult matchAndRewrite(OpTy loadOp,
33                                 PatternRewriter &rewriter) const override;
34 
35 private:
36   void replaceOp(OpTy loadOp, SubViewOp subViewOp,
37                  ArrayRef<Value> sourceIndices,
38                  PatternRewriter &rewriter) const;
39 };
40 
41 /// Merges subview operation with store/transferWriteOp operation.
42 template <typename OpTy>
43 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
44 public:
45   using OpRewritePattern<OpTy>::OpRewritePattern;
46 
47   LogicalResult matchAndRewrite(OpTy storeOp,
48                                 PatternRewriter &rewriter) const override;
49 
50 private:
51   void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
52                  ArrayRef<Value> sourceIndices,
53                  PatternRewriter &rewriter) const;
54 };
55 
56 template <>
replaceOp(LoadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const57 void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
58                                               SubViewOp subViewOp,
59                                               ArrayRef<Value> sourceIndices,
60                                               PatternRewriter &rewriter) const {
61   rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
62                                       sourceIndices);
63 }
64 
65 template <>
replaceOp(vector::TransferReadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const66 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
67     vector::TransferReadOp loadOp, SubViewOp subViewOp,
68     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
69   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
70       loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
71       loadOp.permutation_map(), loadOp.padding(), loadOp.maskedAttr());
72 }
73 
74 template <>
replaceOp(StoreOp storeOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const75 void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
76     StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
77     PatternRewriter &rewriter) const {
78   rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
79                                        subViewOp.source(), sourceIndices);
80 }
81 
82 template <>
replaceOp(vector::TransferWriteOp tranferWriteOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const83 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
84     vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
85     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
86   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
87       tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
88       sourceIndices, tranferWriteOp.permutation_map(),
89       tranferWriteOp.maskedAttr());
90 }
91 } // namespace
92 
93 //===----------------------------------------------------------------------===//
94 // Utility functions for op legalization.
95 //===----------------------------------------------------------------------===//
96 
97 /// Given the 'indices' of an load/store operation where the memref is a result
98 /// of a subview op, returns the indices w.r.t to the source memref of the
99 /// subview op. For example
100 ///
101 /// %0 = ... : memref<12x42xf32>
102 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
103 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
104 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
105 ///
106 /// could be folded into
107 ///
108 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
109 ///          memref<12x42xf32>
110 static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)111 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
112                      SubViewOp subViewOp, ValueRange indices,
113                      SmallVectorImpl<Value> &sourceIndices) {
114   // TODO: Aborting when the offsets are static. There might be a way to fold
115   // the subview op with load even if the offsets have been canonicalized
116   // away.
117   SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
118   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
119   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
120   assert(opRanges.size() == indices.size() &&
121          "expected as many indices as rank of subview op result type");
122 
123   // New indices for the load are the current indices * subview_stride +
124   // subview_offset.
125   sourceIndices.resize(indices.size());
126   for (auto index : llvm::enumerate(indices)) {
127     auto offset = *(opOffsets.begin() + index.index());
128     auto stride = *(opStrides.begin() + index.index());
129     auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
130     sourceIndices[index.index()] =
131         rewriter.create<AddIOp>(loc, offset, mul).getResult();
132   }
133   return success();
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Folding SubViewOp and LoadOp/TransferReadOp.
138 //===----------------------------------------------------------------------===//
139 
140 template <typename OpTy>
141 LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const142 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
143                                              PatternRewriter &rewriter) const {
144   auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
145   if (!subViewOp) {
146     return failure();
147   }
148   SmallVector<Value, 4> sourceIndices;
149   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
150                                   loadOp.indices(), sourceIndices)))
151     return failure();
152 
153   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
154   return success();
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // Folding SubViewOp and StoreOp/TransferWriteOp.
159 //===----------------------------------------------------------------------===//
160 
161 template <typename OpTy>
162 LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const163 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
164                                               PatternRewriter &rewriter) const {
165   auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
166   if (!subViewOp) {
167     return failure();
168   }
169   SmallVector<Value, 4> sourceIndices;
170   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
171                                   storeOp.indices(), sourceIndices)))
172     return failure();
173 
174   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
175   return success();
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // Hook for adding patterns.
180 //===----------------------------------------------------------------------===//
181 
populateStdLegalizationPatternsForSPIRVLowering(MLIRContext * context,OwningRewritePatternList & patterns)182 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
183     MLIRContext *context, OwningRewritePatternList &patterns) {
184   patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
185                   LoadOpOfSubViewFolder<vector::TransferReadOp>,
186                   StoreOpOfSubViewFolder<StoreOp>,
187                   StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Pass for testing just the legalization patterns.
192 //===----------------------------------------------------------------------===//
193 
194 namespace {
195 struct SPIRVLegalization final
196     : public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
197   void runOnOperation() override;
198 };
199 } // namespace
200 
runOnOperation()201 void SPIRVLegalization::runOnOperation() {
202   OwningRewritePatternList patterns;
203   auto *context = &getContext();
204   populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
205   applyPatternsAndFoldGreedily(getOperation()->getRegions(),
206                                std::move(patterns));
207 }
208 
createLegalizeStdOpsForSPIRVLoweringPass()209 std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
210   return std::make_unique<SPIRVLegalization>();
211 }
212