1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
14
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20
21 using namespace mlir;
22
23 //===----------------------------------------------------------------------===//
24 // Common utility functions
25 //===----------------------------------------------------------------------===//
26
27 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
28 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute boolAttr)29 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
30 if (!boolAttr)
31 return llvm::None;
32
33 auto type = boolAttr.getType();
34 if (type.isInteger(1)) {
35 auto attr = boolAttr.cast<BoolAttr>();
36 return attr.getValue();
37 }
38 if (auto vecType = type.cast<VectorType>()) {
39 if (vecType.getElementType().isInteger(1))
40 if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
41 return attr.getSplatValue<bool>();
42 }
43 return llvm::None;
44 }
45
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)48 static Attribute extractCompositeElement(Attribute composite,
49 ArrayRef<unsigned> indices) {
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = composite.dyn_cast<ElementsAttr>()) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValue({indices[0]});
60 }
61
62 if (auto array = composite.dyn_cast<ArrayAttr>()) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
66 }
67
68 return {};
69 }
70
71 //===----------------------------------------------------------------------===//
72 // TableGen'erated canonicalizers
73 //===----------------------------------------------------------------------===//
74
75 namespace {
76 #include "SPIRVCanonicalization.inc"
77 }
78
79 //===----------------------------------------------------------------------===//
80 // spv.AccessChainOp
81 //===----------------------------------------------------------------------===//
82
83 namespace {
84
85 /// Combines chained `spirv::AccessChainOp` operations into one
86 /// `spirv::AccessChainOp` operation.
87 struct CombineChainedAccessChain
88 : public OpRewritePattern<spirv::AccessChainOp> {
89 using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
90
matchAndRewrite__anonfa547f020211::CombineChainedAccessChain91 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
92 PatternRewriter &rewriter) const override {
93 auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
94 accessChainOp.base_ptr().getDefiningOp());
95
96 if (!parentAccessChainOp) {
97 return failure();
98 }
99
100 // Combine indices.
101 SmallVector<Value, 4> indices(parentAccessChainOp.indices());
102 indices.append(accessChainOp.indices().begin(),
103 accessChainOp.indices().end());
104
105 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
106 accessChainOp, parentAccessChainOp.base_ptr(), indices);
107
108 return success();
109 }
110 };
111 } // end anonymous namespace
112
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)113 void spirv::AccessChainOp::getCanonicalizationPatterns(
114 OwningRewritePatternList &results, MLIRContext *context) {
115 results.insert<CombineChainedAccessChain>(context);
116 }
117
118 //===----------------------------------------------------------------------===//
119 // spv.BitcastOp
120 //===----------------------------------------------------------------------===//
121
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)122 void spirv::BitcastOp::getCanonicalizationPatterns(
123 OwningRewritePatternList &results, MLIRContext *context) {
124 results.insert<ConvertChainedBitcast>(context);
125 }
126
127 //===----------------------------------------------------------------------===//
128 // spv.CompositeExtractOp
129 //===----------------------------------------------------------------------===//
130
fold(ArrayRef<Attribute> operands)131 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
132 assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
133 auto indexVector =
134 llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
135 return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
136 }));
137 return extractCompositeElement(operands[0], indexVector);
138 }
139
140 //===----------------------------------------------------------------------===//
141 // spv.constant
142 //===----------------------------------------------------------------------===//
143
fold(ArrayRef<Attribute> operands)144 OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
145 assert(operands.empty() && "spv.constant has no operands");
146 return value();
147 }
148
149 //===----------------------------------------------------------------------===//
150 // spv.IAdd
151 //===----------------------------------------------------------------------===//
152
fold(ArrayRef<Attribute> operands)153 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
154 assert(operands.size() == 2 && "spv.IAdd expects two operands");
155 // x + 0 = x
156 if (matchPattern(operand2(), m_Zero()))
157 return operand1();
158
159 // According to the SPIR-V spec:
160 //
161 // The resulting value will equal the low-order N bits of the correct result
162 // R, where N is the component width and R is computed with enough precision
163 // to avoid overflow and underflow.
164 return constFoldBinaryOp<IntegerAttr>(operands,
165 [](APInt a, APInt b) { return a + b; });
166 }
167
168 //===----------------------------------------------------------------------===//
169 // spv.IMul
170 //===----------------------------------------------------------------------===//
171
fold(ArrayRef<Attribute> operands)172 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
173 assert(operands.size() == 2 && "spv.IMul expects two operands");
174 // x * 0 == 0
175 if (matchPattern(operand2(), m_Zero()))
176 return operand2();
177 // x * 1 = x
178 if (matchPattern(operand2(), m_One()))
179 return operand1();
180
181 // According to the SPIR-V spec:
182 //
183 // The resulting value will equal the low-order N bits of the correct result
184 // R, where N is the component width and R is computed with enough precision
185 // to avoid overflow and underflow.
186 return constFoldBinaryOp<IntegerAttr>(operands,
187 [](APInt a, APInt b) { return a * b; });
188 }
189
190 //===----------------------------------------------------------------------===//
191 // spv.ISub
192 //===----------------------------------------------------------------------===//
193
fold(ArrayRef<Attribute> operands)194 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
195 // x - x = 0
196 if (operand1() == operand2())
197 return Builder(getContext()).getIntegerAttr(getType(), 0);
198
199 // According to the SPIR-V spec:
200 //
201 // The resulting value will equal the low-order N bits of the correct result
202 // R, where N is the component width and R is computed with enough precision
203 // to avoid overflow and underflow.
204 return constFoldBinaryOp<IntegerAttr>(operands,
205 [](APInt a, APInt b) { return a - b; });
206 }
207
208 //===----------------------------------------------------------------------===//
209 // spv.LogicalAnd
210 //===----------------------------------------------------------------------===//
211
fold(ArrayRef<Attribute> operands)212 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
213 assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
214
215 if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
216 // x && true = x
217 if (rhs.getValue())
218 return operand1();
219
220 // x && false = false
221 if (!rhs.getValue())
222 return operands.back();
223 }
224
225 return Attribute();
226 }
227
228 //===----------------------------------------------------------------------===//
229 // spv.LogicalNot
230 //===----------------------------------------------------------------------===//
231
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)232 void spirv::LogicalNotOp::getCanonicalizationPatterns(
233 OwningRewritePatternList &results, MLIRContext *context) {
234 results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
235 ConvertLogicalNotOfLogicalEqual,
236 ConvertLogicalNotOfLogicalNotEqual>(context);
237 }
238
239 //===----------------------------------------------------------------------===//
240 // spv.LogicalOr
241 //===----------------------------------------------------------------------===//
242
fold(ArrayRef<Attribute> operands)243 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
244 assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
245
246 if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
247 if (rhs.getValue())
248 // x || true = true
249 return operands.back();
250
251 // x || false = x
252 if (!rhs.getValue())
253 return operand1();
254 }
255
256 return Attribute();
257 }
258
259 //===----------------------------------------------------------------------===//
260 // spv.selection
261 //===----------------------------------------------------------------------===//
262
263 namespace {
264 // Blocks from the given `spv.selection` operation must satisfy the following
265 // layout:
266 //
267 // +-----------------------------------------------+
268 // | header block |
269 // | spv.BranchConditionalOp %cond, ^case0, ^case1 |
270 // +-----------------------------------------------+
271 // / \
272 // ...
273 //
274 //
275 // +------------------------+ +------------------------+
276 // | case #0 | | case #1 |
277 // | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
278 // | spv.Branch ^merge | | spv.Branch ^merge |
279 // +------------------------+ +------------------------+
280 //
281 //
282 // ...
283 // \ /
284 // v
285 // +-------------+
286 // | merge block |
287 // +-------------+
288 //
289 struct ConvertSelectionOpToSelect
290 : public OpRewritePattern<spirv::SelectionOp> {
291 using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
292
matchAndRewrite__anonfa547f020711::ConvertSelectionOpToSelect293 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
294 PatternRewriter &rewriter) const override {
295 auto *op = selectionOp.getOperation();
296 auto &body = op->getRegion(0);
297 // Verifier allows an empty region for `spv.selection`.
298 if (body.empty()) {
299 return failure();
300 }
301
302 // Check that region consists of 4 blocks:
303 // header block, `true` block, `false` block and merge block.
304 if (std::distance(body.begin(), body.end()) != 4) {
305 return failure();
306 }
307
308 auto *headerBlock = selectionOp.getHeaderBlock();
309 if (!onlyContainsBranchConditionalOp(headerBlock)) {
310 return failure();
311 }
312
313 auto brConditionalOp =
314 cast<spirv::BranchConditionalOp>(headerBlock->front());
315
316 auto *trueBlock = brConditionalOp.getSuccessor(0);
317 auto *falseBlock = brConditionalOp.getSuccessor(1);
318 auto *mergeBlock = selectionOp.getMergeBlock();
319
320 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
321 return failure();
322
323 auto trueValue = getSrcValue(trueBlock);
324 auto falseValue = getSrcValue(falseBlock);
325 auto ptrValue = getDstPtr(trueBlock);
326 auto storeOpAttributes =
327 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
328
329 auto selectOp = rewriter.create<spirv::SelectOp>(
330 selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
331 trueValue, falseValue);
332 rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
333 selectOp.getResult(), storeOpAttributes);
334
335 // `spv.selection` is not needed anymore.
336 rewriter.eraseOp(op);
337 return success();
338 }
339
340 private:
341 // Checks that given blocks follow the following rules:
342 // 1. Each conditional block consists of two operations, the first operation
343 // is a `spv.Store` and the last operation is a `spv.Branch`.
344 // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
345 // 3. A control flow goes into the given merge block from the given
346 // conditional blocks.
347 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
348 Block *mergeBlock) const;
349
onlyContainsBranchConditionalOp__anonfa547f020711::ConvertSelectionOpToSelect350 bool onlyContainsBranchConditionalOp(Block *block) const {
351 return std::next(block->begin()) == block->end() &&
352 isa<spirv::BranchConditionalOp>(block->front());
353 }
354
isSameAttrList__anonfa547f020711::ConvertSelectionOpToSelect355 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
356 return lhs->getAttrDictionary() == rhs->getAttrDictionary();
357 }
358
359
360 // Returns a source value for the given block.
getSrcValue__anonfa547f020711::ConvertSelectionOpToSelect361 Value getSrcValue(Block *block) const {
362 auto storeOp = cast<spirv::StoreOp>(block->front());
363 return storeOp.value();
364 }
365
366 // Returns a destination value for the given block.
getDstPtr__anonfa547f020711::ConvertSelectionOpToSelect367 Value getDstPtr(Block *block) const {
368 auto storeOp = cast<spirv::StoreOp>(block->front());
369 return storeOp.ptr();
370 }
371 };
372
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const373 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
374 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
375 // Each block must consists of 2 operations.
376 if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
377 (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
378 return failure();
379 }
380
381 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
382 auto trueBrBranchOp =
383 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
384 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
385 auto falseBrBranchOp =
386 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
387
388 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
389 !falseBrBranchOp) {
390 return failure();
391 }
392
393 // Checks that given type is valid for `spv.SelectOp`.
394 // According to SPIR-V spec:
395 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
396 // Starting with version 1.4, Result Type can additionally be a composite type
397 // other than a vector."
398 bool isScalarOrVector = trueBrStoreOp.value()
399 .getType()
400 .cast<spirv::SPIRVType>()
401 .isScalarOrVector();
402
403 // Check that each `spv.Store` uses the same pointer, memory access
404 // attributes and a valid type of the value.
405 if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
406 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
407 return failure();
408 }
409
410 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
411 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
412 return failure();
413 }
414
415 return success();
416 }
417 } // end anonymous namespace
418
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)419 void spirv::SelectionOp::getCanonicalizationPatterns(
420 OwningRewritePatternList &results, MLIRContext *context) {
421 results.insert<ConvertSelectionOpToSelect>(context);
422 }
423