• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the conversion patterns from SCF ops to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/IR/BuiltinOps.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 struct ScfToSPIRVContextImpl {
23   // Map between the spirv region control flow operation (spv.loop or
24   // spv.selection) to the VariableOp created to store the region results. The
25   // order of the VariableOp matches the order of the results.
26   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
27 };
28 } // namespace mlir
29 
30 /// We use ScfToSPIRVContext to store information about the lowering of the scf
31 /// region that need to be used later on. When we lower scf.for/scf.if we create
32 /// VariableOp to store the results. We need to keep track of the VariableOp
33 /// created as we need to insert stores into them when lowering Yield. Those
34 /// StoreOp cannot be created earlier as they may use a different type than
35 /// yield operands.
ScfToSPIRVContext()36 ScfToSPIRVContext::ScfToSPIRVContext() {
37   impl = std::make_unique<ScfToSPIRVContextImpl>();
38 }
39 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
40 
41 namespace {
42 /// Common class for all vector to GPU patterns.
43 template <typename OpTy>
44 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
45 public:
SCFToSPIRVPattern(MLIRContext * context,SPIRVTypeConverter & converter,ScfToSPIRVContextImpl * scfToSPIRVContext)46   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
47                           ScfToSPIRVContextImpl *scfToSPIRVContext)
48       : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
49         scfToSPIRVContext(scfToSPIRVContext) {}
50 
51 protected:
52   ScfToSPIRVContextImpl *scfToSPIRVContext;
53 };
54 
55 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
56 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
57 public:
58   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
59 
60   LogicalResult
61   matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override;
63 };
64 
65 /// Pattern to convert a scf::IfOp within kernel functions into
66 /// spirv::SelectionOp.
67 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
68 public:
69   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
70 
71   LogicalResult
72   matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
73                   ConversionPatternRewriter &rewriter) const override;
74 };
75 
76 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
77 public:
78   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
79 
80   LogicalResult
81   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
82                   ConversionPatternRewriter &rewriter) const override;
83 };
84 } // namespace
85 
86 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
87 /// We create VariableOp to handle the results value of the control flow region.
88 /// spv.loop/spv.selection currently don't yield value. Right after the loop
89 /// we load the value from the allocation and use it as the SCF op result.
90 template <typename ScfOp, typename OpTy>
replaceSCFOutputValue(ScfOp scfOp,OpTy newOp,SPIRVTypeConverter & typeConverter,ConversionPatternRewriter & rewriter,ScfToSPIRVContextImpl * scfToSPIRVContext,ArrayRef<Type> returnTypes)91 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
92                                   SPIRVTypeConverter &typeConverter,
93                                   ConversionPatternRewriter &rewriter,
94                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
95                                   ArrayRef<Type> returnTypes) {
96 
97   Location loc = scfOp.getLoc();
98   auto &allocas = scfToSPIRVContext->outputVars[newOp];
99   // Clearing the allocas is necessary in case a dialect conversion path failed
100   // previously, and this is the second attempt of this conversion.
101   allocas.clear();
102   SmallVector<Value, 8> resultValue;
103   for (Type convertedType : returnTypes) {
104     auto pointerType =
105         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
106     rewriter.setInsertionPoint(newOp);
107     auto alloc = rewriter.create<spirv::VariableOp>(
108         loc, pointerType, spirv::StorageClass::Function,
109         /*initializer=*/nullptr);
110     allocas.push_back(alloc);
111     rewriter.setInsertionPointAfter(newOp);
112     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
113     resultValue.push_back(loadResult);
114   }
115   rewriter.replaceOp(scfOp, resultValue);
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // scf::ForOp.
120 //===----------------------------------------------------------------------===//
121 
122 LogicalResult
matchAndRewrite(scf::ForOp forOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const123 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
124                                  ConversionPatternRewriter &rewriter) const {
125   // scf::ForOp can be lowered to the structured control flow represented by
126   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
127   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
128   // single back edge from the continue to header block, and a single exit from
129   // header to merge.
130   scf::ForOpAdaptor forOperands(operands);
131   auto loc = forOp.getLoc();
132   auto loopControl = rewriter.getI32IntegerAttr(
133       static_cast<uint32_t>(spirv::LoopControl::None));
134   auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
135   loopOp.addEntryAndMergeBlock();
136 
137   OpBuilder::InsertionGuard guard(rewriter);
138   // Create the block for the header.
139   auto *header = new Block();
140   // Insert the header.
141   loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
142 
143   // Create the new induction variable to use.
144   BlockArgument newIndVar =
145       header->addArgument(forOperands.lowerBound().getType());
146   for (Value arg : forOperands.initArgs())
147     header->addArgument(arg.getType());
148   Block *body = forOp.getBody();
149 
150   // Apply signature conversion to the body of the forOp. It has a single block,
151   // with argument which is the induction variable. That has to be replaced with
152   // the new induction variable.
153   TypeConverter::SignatureConversion signatureConverter(
154       body->getNumArguments());
155   signatureConverter.remapInput(0, newIndVar);
156   for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
157     signatureConverter.remapInput(i, header->getArgument(i));
158   body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
159                                            signatureConverter);
160 
161   // Move the blocks from the forOp into the loopOp. This is the body of the
162   // loopOp.
163   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
164                               std::next(loopOp.body().begin(), 2));
165 
166   SmallVector<Value, 8> args(1, forOperands.lowerBound());
167   args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
168   // Branch into it from the entry.
169   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
170   rewriter.create<spirv::BranchOp>(loc, header, args);
171 
172   // Generate the rest of the loop header.
173   rewriter.setInsertionPointToEnd(header);
174   auto *mergeBlock = loopOp.getMergeBlock();
175   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
176       loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
177 
178   rewriter.create<spirv::BranchConditionalOp>(
179       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
180 
181   // Generate instructions to increment the step of the induction variable and
182   // branch to the header.
183   Block *continueBlock = loopOp.getContinueBlock();
184   rewriter.setInsertionPointToEnd(continueBlock);
185 
186   // Add the step to the induction variable and branch to the header.
187   Value updatedIndVar = rewriter.create<spirv::IAddOp>(
188       loc, newIndVar.getType(), newIndVar, forOperands.step());
189   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
190 
191   // Infer the return types from the init operands. Vector type may get
192   // converted to CooperativeMatrix or to Vector type, to avoid having complex
193   // extra logic to figure out the right type we just infer it from the Init
194   // operands.
195   SmallVector<Type, 8> initTypes;
196   for (auto arg : forOperands.initArgs())
197     initTypes.push_back(arg.getType());
198   replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
199                         scfToSPIRVContext, initTypes);
200   return success();
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // scf::IfOp.
205 //===----------------------------------------------------------------------===//
206 
207 LogicalResult
matchAndRewrite(scf::IfOp ifOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const208 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
209                                 ConversionPatternRewriter &rewriter) const {
210   // When lowering `scf::IfOp` we explicitly create a selection header block
211   // before the control flow diverges and a merge block where control flow
212   // subsequently converges.
213   scf::IfOpAdaptor ifOperands(operands);
214   auto loc = ifOp.getLoc();
215 
216   // Create `spv.selection` operation, selection header block and merge block.
217   auto selectionControl = rewriter.getI32IntegerAttr(
218       static_cast<uint32_t>(spirv::SelectionControl::None));
219   auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
220   auto *mergeBlock =
221       rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
222   rewriter.create<spirv::MergeOp>(loc);
223 
224   OpBuilder::InsertionGuard guard(rewriter);
225   auto *selectionHeaderBlock =
226       rewriter.createBlock(&selectionOp.body().front());
227 
228   // Inline `then` region before the merge block and branch to it.
229   auto &thenRegion = ifOp.thenRegion();
230   auto *thenBlock = &thenRegion.front();
231   rewriter.setInsertionPointToEnd(&thenRegion.back());
232   rewriter.create<spirv::BranchOp>(loc, mergeBlock);
233   rewriter.inlineRegionBefore(thenRegion, mergeBlock);
234 
235   auto *elseBlock = mergeBlock;
236   // If `else` region is not empty, inline that region before the merge block
237   // and branch to it.
238   if (!ifOp.elseRegion().empty()) {
239     auto &elseRegion = ifOp.elseRegion();
240     elseBlock = &elseRegion.front();
241     rewriter.setInsertionPointToEnd(&elseRegion.back());
242     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
243     rewriter.inlineRegionBefore(elseRegion, mergeBlock);
244   }
245 
246   // Create a `spv.BranchConditional` operation for selection header block.
247   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
248   rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
249                                               thenBlock, ArrayRef<Value>(),
250                                               elseBlock, ArrayRef<Value>());
251 
252   SmallVector<Type, 8> returnTypes;
253   for (auto result : ifOp.results()) {
254     auto convertedType = typeConverter.convertType(result.getType());
255     returnTypes.push_back(convertedType);
256   }
257   replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
258                         scfToSPIRVContext, returnTypes);
259   return success();
260 }
261 
262 /// Yield is lowered to stores to the VariableOp created during lowering of the
263 /// parent region. For loops we also need to update the branch looping back to
264 /// the header with the loop carried values.
matchAndRewrite(scf::YieldOp terminatorOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const265 LogicalResult TerminatorOpConversion::matchAndRewrite(
266     scf::YieldOp terminatorOp, ArrayRef<Value> operands,
267     ConversionPatternRewriter &rewriter) const {
268   // If the region is return values, store each value into the associated
269   // VariableOp created during lowering of the parent region.
270   if (!operands.empty()) {
271     auto loc = terminatorOp.getLoc();
272     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
273     assert(allocas.size() == operands.size());
274     for (unsigned i = 0, e = operands.size(); i < e; i++)
275       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
276     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
277       // For loops we also need to update the branch jumping back to the header.
278       auto br =
279           cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
280       SmallVector<Value, 8> args(br.getBlockArguments());
281       args.append(operands.begin(), operands.end());
282       rewriter.setInsertionPoint(br);
283       rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
284                                        args);
285       rewriter.eraseOp(br);
286     }
287   }
288   rewriter.eraseOp(terminatorOp);
289   return success();
290 }
291 
populateSCFToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,ScfToSPIRVContext & scfToSPIRVContext,OwningRewritePatternList & patterns)292 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
293                                       SPIRVTypeConverter &typeConverter,
294                                       ScfToSPIRVContext &scfToSPIRVContext,
295                                       OwningRewritePatternList &patterns) {
296   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
297       context, typeConverter, scfToSPIRVContext.getImpl());
298 }
299