1 //===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the conversion patterns from SCF ops to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/IR/BuiltinOps.h"
18
19 using namespace mlir;
20
21 namespace mlir {
22 struct ScfToSPIRVContextImpl {
23 // Map between the spirv region control flow operation (spv.loop or
24 // spv.selection) to the VariableOp created to store the region results. The
25 // order of the VariableOp matches the order of the results.
26 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
27 };
28 } // namespace mlir
29
30 /// We use ScfToSPIRVContext to store information about the lowering of the scf
31 /// region that need to be used later on. When we lower scf.for/scf.if we create
32 /// VariableOp to store the results. We need to keep track of the VariableOp
33 /// created as we need to insert stores into them when lowering Yield. Those
34 /// StoreOp cannot be created earlier as they may use a different type than
35 /// yield operands.
ScfToSPIRVContext()36 ScfToSPIRVContext::ScfToSPIRVContext() {
37 impl = std::make_unique<ScfToSPIRVContextImpl>();
38 }
39 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
40
41 namespace {
42 /// Common class for all vector to GPU patterns.
43 template <typename OpTy>
44 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
45 public:
SCFToSPIRVPattern(MLIRContext * context,SPIRVTypeConverter & converter,ScfToSPIRVContextImpl * scfToSPIRVContext)46 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
47 ScfToSPIRVContextImpl *scfToSPIRVContext)
48 : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
49 scfToSPIRVContext(scfToSPIRVContext) {}
50
51 protected:
52 ScfToSPIRVContextImpl *scfToSPIRVContext;
53 };
54
55 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
56 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
57 public:
58 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
59
60 LogicalResult
61 matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
62 ConversionPatternRewriter &rewriter) const override;
63 };
64
65 /// Pattern to convert a scf::IfOp within kernel functions into
66 /// spirv::SelectionOp.
67 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
68 public:
69 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
70
71 LogicalResult
72 matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
73 ConversionPatternRewriter &rewriter) const override;
74 };
75
76 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
77 public:
78 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
79
80 LogicalResult
81 matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
82 ConversionPatternRewriter &rewriter) const override;
83 };
84 } // namespace
85
86 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
87 /// We create VariableOp to handle the results value of the control flow region.
88 /// spv.loop/spv.selection currently don't yield value. Right after the loop
89 /// we load the value from the allocation and use it as the SCF op result.
90 template <typename ScfOp, typename OpTy>
replaceSCFOutputValue(ScfOp scfOp,OpTy newOp,SPIRVTypeConverter & typeConverter,ConversionPatternRewriter & rewriter,ScfToSPIRVContextImpl * scfToSPIRVContext,ArrayRef<Type> returnTypes)91 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
92 SPIRVTypeConverter &typeConverter,
93 ConversionPatternRewriter &rewriter,
94 ScfToSPIRVContextImpl *scfToSPIRVContext,
95 ArrayRef<Type> returnTypes) {
96
97 Location loc = scfOp.getLoc();
98 auto &allocas = scfToSPIRVContext->outputVars[newOp];
99 // Clearing the allocas is necessary in case a dialect conversion path failed
100 // previously, and this is the second attempt of this conversion.
101 allocas.clear();
102 SmallVector<Value, 8> resultValue;
103 for (Type convertedType : returnTypes) {
104 auto pointerType =
105 spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
106 rewriter.setInsertionPoint(newOp);
107 auto alloc = rewriter.create<spirv::VariableOp>(
108 loc, pointerType, spirv::StorageClass::Function,
109 /*initializer=*/nullptr);
110 allocas.push_back(alloc);
111 rewriter.setInsertionPointAfter(newOp);
112 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
113 resultValue.push_back(loadResult);
114 }
115 rewriter.replaceOp(scfOp, resultValue);
116 }
117
118 //===----------------------------------------------------------------------===//
119 // scf::ForOp.
120 //===----------------------------------------------------------------------===//
121
122 LogicalResult
matchAndRewrite(scf::ForOp forOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const123 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
124 ConversionPatternRewriter &rewriter) const {
125 // scf::ForOp can be lowered to the structured control flow represented by
126 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
127 // latch and the merge block the exit block. The resulting spirv::LoopOp has a
128 // single back edge from the continue to header block, and a single exit from
129 // header to merge.
130 scf::ForOpAdaptor forOperands(operands);
131 auto loc = forOp.getLoc();
132 auto loopControl = rewriter.getI32IntegerAttr(
133 static_cast<uint32_t>(spirv::LoopControl::None));
134 auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
135 loopOp.addEntryAndMergeBlock();
136
137 OpBuilder::InsertionGuard guard(rewriter);
138 // Create the block for the header.
139 auto *header = new Block();
140 // Insert the header.
141 loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
142
143 // Create the new induction variable to use.
144 BlockArgument newIndVar =
145 header->addArgument(forOperands.lowerBound().getType());
146 for (Value arg : forOperands.initArgs())
147 header->addArgument(arg.getType());
148 Block *body = forOp.getBody();
149
150 // Apply signature conversion to the body of the forOp. It has a single block,
151 // with argument which is the induction variable. That has to be replaced with
152 // the new induction variable.
153 TypeConverter::SignatureConversion signatureConverter(
154 body->getNumArguments());
155 signatureConverter.remapInput(0, newIndVar);
156 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
157 signatureConverter.remapInput(i, header->getArgument(i));
158 body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
159 signatureConverter);
160
161 // Move the blocks from the forOp into the loopOp. This is the body of the
162 // loopOp.
163 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
164 std::next(loopOp.body().begin(), 2));
165
166 SmallVector<Value, 8> args(1, forOperands.lowerBound());
167 args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
168 // Branch into it from the entry.
169 rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
170 rewriter.create<spirv::BranchOp>(loc, header, args);
171
172 // Generate the rest of the loop header.
173 rewriter.setInsertionPointToEnd(header);
174 auto *mergeBlock = loopOp.getMergeBlock();
175 auto cmpOp = rewriter.create<spirv::SLessThanOp>(
176 loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
177
178 rewriter.create<spirv::BranchConditionalOp>(
179 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
180
181 // Generate instructions to increment the step of the induction variable and
182 // branch to the header.
183 Block *continueBlock = loopOp.getContinueBlock();
184 rewriter.setInsertionPointToEnd(continueBlock);
185
186 // Add the step to the induction variable and branch to the header.
187 Value updatedIndVar = rewriter.create<spirv::IAddOp>(
188 loc, newIndVar.getType(), newIndVar, forOperands.step());
189 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
190
191 // Infer the return types from the init operands. Vector type may get
192 // converted to CooperativeMatrix or to Vector type, to avoid having complex
193 // extra logic to figure out the right type we just infer it from the Init
194 // operands.
195 SmallVector<Type, 8> initTypes;
196 for (auto arg : forOperands.initArgs())
197 initTypes.push_back(arg.getType());
198 replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
199 scfToSPIRVContext, initTypes);
200 return success();
201 }
202
203 //===----------------------------------------------------------------------===//
204 // scf::IfOp.
205 //===----------------------------------------------------------------------===//
206
207 LogicalResult
matchAndRewrite(scf::IfOp ifOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const208 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
209 ConversionPatternRewriter &rewriter) const {
210 // When lowering `scf::IfOp` we explicitly create a selection header block
211 // before the control flow diverges and a merge block where control flow
212 // subsequently converges.
213 scf::IfOpAdaptor ifOperands(operands);
214 auto loc = ifOp.getLoc();
215
216 // Create `spv.selection` operation, selection header block and merge block.
217 auto selectionControl = rewriter.getI32IntegerAttr(
218 static_cast<uint32_t>(spirv::SelectionControl::None));
219 auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
220 auto *mergeBlock =
221 rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
222 rewriter.create<spirv::MergeOp>(loc);
223
224 OpBuilder::InsertionGuard guard(rewriter);
225 auto *selectionHeaderBlock =
226 rewriter.createBlock(&selectionOp.body().front());
227
228 // Inline `then` region before the merge block and branch to it.
229 auto &thenRegion = ifOp.thenRegion();
230 auto *thenBlock = &thenRegion.front();
231 rewriter.setInsertionPointToEnd(&thenRegion.back());
232 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
233 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
234
235 auto *elseBlock = mergeBlock;
236 // If `else` region is not empty, inline that region before the merge block
237 // and branch to it.
238 if (!ifOp.elseRegion().empty()) {
239 auto &elseRegion = ifOp.elseRegion();
240 elseBlock = &elseRegion.front();
241 rewriter.setInsertionPointToEnd(&elseRegion.back());
242 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
243 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
244 }
245
246 // Create a `spv.BranchConditional` operation for selection header block.
247 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
248 rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
249 thenBlock, ArrayRef<Value>(),
250 elseBlock, ArrayRef<Value>());
251
252 SmallVector<Type, 8> returnTypes;
253 for (auto result : ifOp.results()) {
254 auto convertedType = typeConverter.convertType(result.getType());
255 returnTypes.push_back(convertedType);
256 }
257 replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
258 scfToSPIRVContext, returnTypes);
259 return success();
260 }
261
262 /// Yield is lowered to stores to the VariableOp created during lowering of the
263 /// parent region. For loops we also need to update the branch looping back to
264 /// the header with the loop carried values.
matchAndRewrite(scf::YieldOp terminatorOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const265 LogicalResult TerminatorOpConversion::matchAndRewrite(
266 scf::YieldOp terminatorOp, ArrayRef<Value> operands,
267 ConversionPatternRewriter &rewriter) const {
268 // If the region is return values, store each value into the associated
269 // VariableOp created during lowering of the parent region.
270 if (!operands.empty()) {
271 auto loc = terminatorOp.getLoc();
272 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
273 assert(allocas.size() == operands.size());
274 for (unsigned i = 0, e = operands.size(); i < e; i++)
275 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
276 if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
277 // For loops we also need to update the branch jumping back to the header.
278 auto br =
279 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
280 SmallVector<Value, 8> args(br.getBlockArguments());
281 args.append(operands.begin(), operands.end());
282 rewriter.setInsertionPoint(br);
283 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
284 args);
285 rewriter.eraseOp(br);
286 }
287 }
288 rewriter.eraseOp(terminatorOp);
289 return success();
290 }
291
populateSCFToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,ScfToSPIRVContext & scfToSPIRVContext,OwningRewritePatternList & patterns)292 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
293 SPIRVTypeConverter &typeConverter,
294 ScfToSPIRVContext &scfToSPIRVContext,
295 OwningRewritePatternList &patterns) {
296 patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
297 context, typeConverter, scfToSPIRVContext.getImpl());
298 }
299