1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/SCF/Passes.h"
11 #include "mlir/Dialect/SCF/SCF.h"
12 #include "mlir/Dialect/SCF/Transforms.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Transforms/DialectConversion.h"
15
16 using namespace mlir;
17 using namespace mlir::scf;
18
19 namespace {
20 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
21 public:
22 using OpConversionPattern::OpConversionPattern;
23 LogicalResult
matchAndRewrite(ForOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const24 matchAndRewrite(ForOp op, ArrayRef<Value> operands,
25 ConversionPatternRewriter &rewriter) const override {
26 SmallVector<Type, 6> newResultTypes;
27 for (auto type : op.getResultTypes()) {
28 Type newType = typeConverter->convertType(type);
29 if (!newType)
30 return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
31 newResultTypes.push_back(newType);
32 }
33
34 // Clone the op without the regions and inline the regions from the old op.
35 //
36 // This is a little bit tricky. We have two concerns here:
37 //
38 // 1. We cannot update the op in place because the dialect conversion
39 // framework does not track type changes for ops updated in place, so it
40 // won't insert appropriate materializations on the changed result types.
41 // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
42 // clone the op.
43 //
44 // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
45 // inefficient to recursively clone the regions, there is a correctness
46 // issue: if we clone with the regions, then the dialect conversion
47 // framework thinks that we just inserted all the cloned child ops. But what
48 // we want is to "take" the child regions and let the dialect conversion
49 // framework continue recursively into ops inside those regions (which are
50 // already in its worklist; inlining them into the new op's regions doesn't
51 // remove the child ops from the worklist).
52 ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
53 // Take the region from the old op and put it in the new op.
54 rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
55 newOp.getLoopBody().end());
56
57 // Now, update all the types.
58
59 // Convert the type of the entry block of the ForOp's body.
60 if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
61 *getTypeConverter()))) {
62 return rewriter.notifyMatchFailure(op, "could not convert body types");
63 }
64 // Change the clone to use the updated operands. We could have cloned with
65 // a BlockAndValueMapping, but this seems a bit more direct.
66 newOp->setOperands(operands);
67 // Update the result types to the new converted types.
68 for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
69 std::get<0>(t).setType(std::get<1>(t));
70
71 rewriter.replaceOp(op, newOp.getResults());
72 return success();
73 }
74 };
75 } // namespace
76
77 namespace {
78 class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
79 public:
80 using OpConversionPattern::OpConversionPattern;
81 LogicalResult
matchAndRewrite(IfOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const82 matchAndRewrite(IfOp op, ArrayRef<Value> operands,
83 ConversionPatternRewriter &rewriter) const override {
84 // TODO: Generalize this to any type conversion, not just 1:1.
85 //
86 // We need to implement something more sophisticated here that tracks which
87 // types convert to which other types and does the appropriate
88 // materialization logic.
89 // For example, it's possible that one result type converts to 0 types and
90 // another to 2 types, so newResultTypes would at least be the right size to
91 // not crash in the llvm::zip call below, but then we would set the the
92 // wrong type on the SSA values! These edge cases are also why we cannot
93 // safely use the TypeConverter::convertTypes helper here.
94 SmallVector<Type, 6> newResultTypes;
95 for (auto type : op.getResultTypes()) {
96 Type newType = typeConverter->convertType(type);
97 if (!newType)
98 return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
99 newResultTypes.push_back(newType);
100 }
101
102 // See comments in the ForOp pattern for why we clone without regions and
103 // then inline.
104 IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
105 rewriter.inlineRegionBefore(op.thenRegion(), newOp.thenRegion(),
106 newOp.thenRegion().end());
107 rewriter.inlineRegionBefore(op.elseRegion(), newOp.elseRegion(),
108 newOp.elseRegion().end());
109
110 // Update the operands and types.
111 newOp->setOperands(operands);
112 for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
113 std::get<0>(t).setType(std::get<1>(t));
114 rewriter.replaceOp(op, newOp.getResults());
115 return success();
116 }
117 };
118 } // namespace
119
120 namespace {
121 // When the result types of a ForOp/IfOp get changed, the operand types of the
122 // corresponding yield op need to be changed. In order to trigger the
123 // appropriate type conversions / materializations, we need a dummy pattern.
124 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
125 public:
126 using OpConversionPattern::OpConversionPattern;
127 LogicalResult
matchAndRewrite(scf::YieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const128 matchAndRewrite(scf::YieldOp op, ArrayRef<Value> operands,
129 ConversionPatternRewriter &rewriter) const override {
130 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
131 return success();
132 }
133 };
134 } // namespace
135
populateSCFStructuralTypeConversionsAndLegality(MLIRContext * context,TypeConverter & typeConverter,OwningRewritePatternList & patterns,ConversionTarget & target)136 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
137 MLIRContext *context, TypeConverter &typeConverter,
138 OwningRewritePatternList &patterns, ConversionTarget &target) {
139 patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
140 typeConverter, context);
141 target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
142 return typeConverter.isLegal(op->getResultTypes());
143 });
144 target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
145 // We only have conversions for a subset of ops that use scf.yield
146 // terminators.
147 if (!isa<ForOp, IfOp>(op->getParentOp()))
148 return true;
149 return typeConverter.isLegal(op.getOperandTypes());
150 });
151 }
152