• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
14 
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Common utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
28 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute boolAttr)29 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
30   if (!boolAttr)
31     return llvm::None;
32 
33   auto type = boolAttr.getType();
34   if (type.isInteger(1)) {
35     auto attr = boolAttr.cast<BoolAttr>();
36     return attr.getValue();
37   }
38   if (auto vecType = type.cast<VectorType>()) {
39     if (vecType.getElementType().isInteger(1))
40       if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
41         return attr.getSplatValue<bool>();
42   }
43   return llvm::None;
44 }
45 
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)48 static Attribute extractCompositeElement(Attribute composite,
49                                          ArrayRef<unsigned> indices) {
50   // Check that given composite is a constant.
51   if (!composite)
52     return {};
53   // Return composite itself if we reach the end of the index chain.
54   if (indices.empty())
55     return composite;
56 
57   if (auto vector = composite.dyn_cast<ElementsAttr>()) {
58     assert(indices.size() == 1 && "must have exactly one index for a vector");
59     return vector.getValue({indices[0]});
60   }
61 
62   if (auto array = composite.dyn_cast<ArrayAttr>()) {
63     assert(!indices.empty() && "must have at least one index for an array");
64     return extractCompositeElement(array.getValue()[indices[0]],
65                                    indices.drop_front());
66   }
67 
68   return {};
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // TableGen'erated canonicalizers
73 //===----------------------------------------------------------------------===//
74 
75 namespace {
76 #include "SPIRVCanonicalization.inc"
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // spv.AccessChainOp
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 
85 /// Combines chained `spirv::AccessChainOp` operations into one
86 /// `spirv::AccessChainOp` operation.
87 struct CombineChainedAccessChain
88     : public OpRewritePattern<spirv::AccessChainOp> {
89   using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
90 
matchAndRewrite__anonfa547f020211::CombineChainedAccessChain91   LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
92                                 PatternRewriter &rewriter) const override {
93     auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
94         accessChainOp.base_ptr().getDefiningOp());
95 
96     if (!parentAccessChainOp) {
97       return failure();
98     }
99 
100     // Combine indices.
101     SmallVector<Value, 4> indices(parentAccessChainOp.indices());
102     indices.append(accessChainOp.indices().begin(),
103                    accessChainOp.indices().end());
104 
105     rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
106         accessChainOp, parentAccessChainOp.base_ptr(), indices);
107 
108     return success();
109   }
110 };
111 } // end anonymous namespace
112 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)113 void spirv::AccessChainOp::getCanonicalizationPatterns(
114     OwningRewritePatternList &results, MLIRContext *context) {
115   results.insert<CombineChainedAccessChain>(context);
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // spv.BitcastOp
120 //===----------------------------------------------------------------------===//
121 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)122 void spirv::BitcastOp::getCanonicalizationPatterns(
123     OwningRewritePatternList &results, MLIRContext *context) {
124   results.insert<ConvertChainedBitcast>(context);
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // spv.CompositeExtractOp
129 //===----------------------------------------------------------------------===//
130 
fold(ArrayRef<Attribute> operands)131 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
132   assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
133   auto indexVector =
134       llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
135         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
136       }));
137   return extractCompositeElement(operands[0], indexVector);
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // spv.constant
142 //===----------------------------------------------------------------------===//
143 
fold(ArrayRef<Attribute> operands)144 OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
145   assert(operands.empty() && "spv.constant has no operands");
146   return value();
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // spv.IAdd
151 //===----------------------------------------------------------------------===//
152 
fold(ArrayRef<Attribute> operands)153 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
154   assert(operands.size() == 2 && "spv.IAdd expects two operands");
155   // x + 0 = x
156   if (matchPattern(operand2(), m_Zero()))
157     return operand1();
158 
159   // According to the SPIR-V spec:
160   //
161   // The resulting value will equal the low-order N bits of the correct result
162   // R, where N is the component width and R is computed with enough precision
163   // to avoid overflow and underflow.
164   return constFoldBinaryOp<IntegerAttr>(operands,
165                                         [](APInt a, APInt b) { return a + b; });
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // spv.IMul
170 //===----------------------------------------------------------------------===//
171 
fold(ArrayRef<Attribute> operands)172 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
173   assert(operands.size() == 2 && "spv.IMul expects two operands");
174   // x * 0 == 0
175   if (matchPattern(operand2(), m_Zero()))
176     return operand2();
177   // x * 1 = x
178   if (matchPattern(operand2(), m_One()))
179     return operand1();
180 
181   // According to the SPIR-V spec:
182   //
183   // The resulting value will equal the low-order N bits of the correct result
184   // R, where N is the component width and R is computed with enough precision
185   // to avoid overflow and underflow.
186   return constFoldBinaryOp<IntegerAttr>(operands,
187                                         [](APInt a, APInt b) { return a * b; });
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // spv.ISub
192 //===----------------------------------------------------------------------===//
193 
fold(ArrayRef<Attribute> operands)194 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
195   // x - x = 0
196   if (operand1() == operand2())
197     return Builder(getContext()).getIntegerAttr(getType(), 0);
198 
199   // According to the SPIR-V spec:
200   //
201   // The resulting value will equal the low-order N bits of the correct result
202   // R, where N is the component width and R is computed with enough precision
203   // to avoid overflow and underflow.
204   return constFoldBinaryOp<IntegerAttr>(operands,
205                                         [](APInt a, APInt b) { return a - b; });
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // spv.LogicalAnd
210 //===----------------------------------------------------------------------===//
211 
fold(ArrayRef<Attribute> operands)212 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
213   assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
214 
215   if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
216     // x && true = x
217     if (rhs.getValue())
218       return operand1();
219 
220     // x && false = false
221     if (!rhs.getValue())
222       return operands.back();
223   }
224 
225   return Attribute();
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // spv.LogicalNot
230 //===----------------------------------------------------------------------===//
231 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)232 void spirv::LogicalNotOp::getCanonicalizationPatterns(
233     OwningRewritePatternList &results, MLIRContext *context) {
234   results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
235                  ConvertLogicalNotOfLogicalEqual,
236                  ConvertLogicalNotOfLogicalNotEqual>(context);
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // spv.LogicalOr
241 //===----------------------------------------------------------------------===//
242 
fold(ArrayRef<Attribute> operands)243 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
244   assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
245 
246   if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
247     if (rhs.getValue())
248       // x || true = true
249       return operands.back();
250 
251     // x || false = x
252     if (!rhs.getValue())
253       return operand1();
254   }
255 
256   return Attribute();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // spv.selection
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 // Blocks from the given `spv.selection` operation must satisfy the following
265 // layout:
266 //
267 //       +-----------------------------------------------+
268 //       | header block                                  |
269 //       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
270 //       +-----------------------------------------------+
271 //                            /   \
272 //                             ...
273 //
274 //
275 //   +------------------------+    +------------------------+
276 //   | case #0                |    | case #1                |
277 //   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
278 //   | spv.Branch ^merge      |    | spv.Branch ^merge      |
279 //   +------------------------+    +------------------------+
280 //
281 //
282 //                             ...
283 //                            \   /
284 //                              v
285 //                       +-------------+
286 //                       | merge block |
287 //                       +-------------+
288 //
289 struct ConvertSelectionOpToSelect
290     : public OpRewritePattern<spirv::SelectionOp> {
291   using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
292 
matchAndRewrite__anonfa547f020711::ConvertSelectionOpToSelect293   LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
294                                 PatternRewriter &rewriter) const override {
295     auto *op = selectionOp.getOperation();
296     auto &body = op->getRegion(0);
297     // Verifier allows an empty region for `spv.selection`.
298     if (body.empty()) {
299       return failure();
300     }
301 
302     // Check that region consists of 4 blocks:
303     // header block, `true` block, `false` block and merge block.
304     if (std::distance(body.begin(), body.end()) != 4) {
305       return failure();
306     }
307 
308     auto *headerBlock = selectionOp.getHeaderBlock();
309     if (!onlyContainsBranchConditionalOp(headerBlock)) {
310       return failure();
311     }
312 
313     auto brConditionalOp =
314         cast<spirv::BranchConditionalOp>(headerBlock->front());
315 
316     auto *trueBlock = brConditionalOp.getSuccessor(0);
317     auto *falseBlock = brConditionalOp.getSuccessor(1);
318     auto *mergeBlock = selectionOp.getMergeBlock();
319 
320     if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
321       return failure();
322 
323     auto trueValue = getSrcValue(trueBlock);
324     auto falseValue = getSrcValue(falseBlock);
325     auto ptrValue = getDstPtr(trueBlock);
326     auto storeOpAttributes =
327         cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
328 
329     auto selectOp = rewriter.create<spirv::SelectOp>(
330         selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
331         trueValue, falseValue);
332     rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
333                                     selectOp.getResult(), storeOpAttributes);
334 
335     // `spv.selection` is not needed anymore.
336     rewriter.eraseOp(op);
337     return success();
338   }
339 
340 private:
341   // Checks that given blocks follow the following rules:
342   // 1. Each conditional block consists of two operations, the first operation
343   //    is a `spv.Store` and the last operation is a `spv.Branch`.
344   // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
345   // 3. A control flow goes into the given merge block from the given
346   //    conditional blocks.
347   LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
348                                          Block *mergeBlock) const;
349 
onlyContainsBranchConditionalOp__anonfa547f020711::ConvertSelectionOpToSelect350   bool onlyContainsBranchConditionalOp(Block *block) const {
351     return std::next(block->begin()) == block->end() &&
352            isa<spirv::BranchConditionalOp>(block->front());
353   }
354 
isSameAttrList__anonfa547f020711::ConvertSelectionOpToSelect355   bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
356     return lhs->getAttrDictionary() == rhs->getAttrDictionary();
357   }
358 
359 
360   // Returns a source value for the given block.
getSrcValue__anonfa547f020711::ConvertSelectionOpToSelect361   Value getSrcValue(Block *block) const {
362     auto storeOp = cast<spirv::StoreOp>(block->front());
363     return storeOp.value();
364   }
365 
366   // Returns a destination value for the given block.
getDstPtr__anonfa547f020711::ConvertSelectionOpToSelect367   Value getDstPtr(Block *block) const {
368     auto storeOp = cast<spirv::StoreOp>(block->front());
369     return storeOp.ptr();
370   }
371 };
372 
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const373 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
374     Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
375   // Each block must consists of 2 operations.
376   if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
377       (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
378     return failure();
379   }
380 
381   auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
382   auto trueBrBranchOp =
383       dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
384   auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
385   auto falseBrBranchOp =
386       dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
387 
388   if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
389       !falseBrBranchOp) {
390     return failure();
391   }
392 
393   // Checks that given type is valid for `spv.SelectOp`.
394   // According to SPIR-V spec:
395   // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
396   // Starting with version 1.4, Result Type can additionally be a composite type
397   // other than a vector."
398   bool isScalarOrVector = trueBrStoreOp.value()
399                               .getType()
400                               .cast<spirv::SPIRVType>()
401                               .isScalarOrVector();
402 
403   // Check that each `spv.Store` uses the same pointer, memory access
404   // attributes and a valid type of the value.
405   if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
406       !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
407     return failure();
408   }
409 
410   if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
411       (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
412     return failure();
413   }
414 
415   return success();
416 }
417 } // end anonymous namespace
418 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)419 void spirv::SelectionOp::getCanonicalizationPatterns(
420     OwningRewritePatternList &results, MLIRContext *context) {
421   results.insert<ConvertSelectionOpToSelect>(context);
422 }
423