• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to rewrite sequential chains of
10 // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

#include <utility>
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into
26 /// `spirv::CompositeConstructOp` operation if possible.
27 class RewriteInsertsPass
28     : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
29 public:
30   void runOnOperation() override;
31 
32 private:
33   /// Collects a sequential insertion chain by the given
34   /// `spirv::CompositeInsertOp` operation, if the given operation is the last
35   /// in the chain.
36   LogicalResult
37   collectInsertionChain(spirv::CompositeInsertOp op,
38                         SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
39 };
40 
41 } // anonymous namespace
42 
runOnOperation()43 void RewriteInsertsPass::runOnOperation() {
44   SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
45   getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
46     SmallVector<spirv::CompositeInsertOp, 4> insertions;
47     if (succeeded(collectInsertionChain(op, insertions)))
48       workList.push_back(insertions);
49   });
50 
51   for (const auto &insertions : workList) {
52     auto lastCompositeInsertOp = insertions.back();
53     auto compositeType = lastCompositeInsertOp.getType();
54     auto location = lastCompositeInsertOp.getLoc();
55 
56     SmallVector<Value, 4> operands;
57     // Collect inserted objects.
58     for (auto insertionOp : insertions)
59       operands.push_back(insertionOp.object());
60 
61     OpBuilder builder(lastCompositeInsertOp);
62     auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
63         location, compositeType, operands);
64 
65     lastCompositeInsertOp.replaceAllUsesWith(
66         compositeConstructOp->getResult(0));
67 
68     // Erase ops.
69     for (auto insertOp : llvm::reverse(insertions)) {
70       auto *op = insertOp.getOperation();
71       if (op->use_empty())
72         insertOp.erase();
73     }
74   }
75 }
76 
collectInsertionChain(spirv::CompositeInsertOp op,SmallVectorImpl<spirv::CompositeInsertOp> & insertions)77 LogicalResult RewriteInsertsPass::collectInsertionChain(
78     spirv::CompositeInsertOp op,
79     SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
80   auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
81   // TODO: handle nested composite object.
82   if (indicesArrayAttr.size() == 1) {
83     auto numElements =
84         op.composite().getType().cast<spirv::CompositeType>().getNumElements();
85 
86     auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
87     // Need a last index to collect a sequential chain.
88     if (index + 1 != numElements)
89       return failure();
90 
91     insertions.resize(numElements);
92     while (true) {
93       insertions[index] = op;
94 
95       if (index == 0)
96         return success();
97 
98       op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
99       if (!op)
100         return failure();
101 
102       --index;
103       indicesArrayAttr = op.indices().cast<ArrayAttr>();
104       if ((indicesArrayAttr.size() != 1) ||
105           (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
106         return failure();
107     }
108   }
109   return failure();
110 }
111 
112 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
createRewriteInsertsPass()113 mlir::spirv::createRewriteInsertsPass() {
114   return std::make_unique<RewriteInsertsPass>();
115 }
116