• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/SCF/Passes.h"
11 #include "mlir/Dialect/SCF/SCF.h"
12 #include "mlir/Dialect/SCF/Transforms.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 using namespace mlir::scf;
18 
19 namespace {
20 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
21 public:
22   using OpConversionPattern::OpConversionPattern;
23   LogicalResult
matchAndRewrite(ForOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const24   matchAndRewrite(ForOp op, ArrayRef<Value> operands,
25                   ConversionPatternRewriter &rewriter) const override {
26     SmallVector<Type, 6> newResultTypes;
27     for (auto type : op.getResultTypes()) {
28       Type newType = typeConverter->convertType(type);
29       if (!newType)
30         return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
31       newResultTypes.push_back(newType);
32     }
33 
34     // Clone the op without the regions and inline the regions from the old op.
35     //
36     // This is a little bit tricky. We have two concerns here:
37     //
38     // 1. We cannot update the op in place because the dialect conversion
39     // framework does not track type changes for ops updated in place, so it
40     // won't insert appropriate materializations on the changed result types.
41     // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
42     // clone the op.
43     //
44     // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
45     // inefficient to recursively clone the regions, there is a correctness
46     // issue: if we clone with the regions, then the dialect conversion
47     // framework thinks that we just inserted all the cloned child ops. But what
48     // we want is to "take" the child regions and let the dialect conversion
49     // framework continue recursively into ops inside those regions (which are
50     // already in its worklist; inlining them into the new op's regions doesn't
51     // remove the child ops from the worklist).
52     ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
53     // Take the region from the old op and put it in the new op.
54     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
55                                 newOp.getLoopBody().end());
56 
57     // Now, update all the types.
58 
59     // Convert the type of the entry block of the ForOp's body.
60     if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
61                                            *getTypeConverter()))) {
62       return rewriter.notifyMatchFailure(op, "could not convert body types");
63     }
64     // Change the clone to use the updated operands. We could have cloned with
65     // a BlockAndValueMapping, but this seems a bit more direct.
66     newOp->setOperands(operands);
67     // Update the result types to the new converted types.
68     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
69       std::get<0>(t).setType(std::get<1>(t));
70 
71     rewriter.replaceOp(op, newOp.getResults());
72     return success();
73   }
74 };
75 } // namespace
76 
77 namespace {
78 class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
79 public:
80   using OpConversionPattern::OpConversionPattern;
81   LogicalResult
matchAndRewrite(IfOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const82   matchAndRewrite(IfOp op, ArrayRef<Value> operands,
83                   ConversionPatternRewriter &rewriter) const override {
84     // TODO: Generalize this to any type conversion, not just 1:1.
85     //
86     // We need to implement something more sophisticated here that tracks which
87     // types convert to which other types and does the appropriate
88     // materialization logic.
89     // For example, it's possible that one result type converts to 0 types and
90     // another to 2 types, so newResultTypes would at least be the right size to
91     // not crash in the llvm::zip call below, but then we would set the the
92     // wrong type on the SSA values! These edge cases are also why we cannot
93     // safely use the TypeConverter::convertTypes helper here.
94     SmallVector<Type, 6> newResultTypes;
95     for (auto type : op.getResultTypes()) {
96       Type newType = typeConverter->convertType(type);
97       if (!newType)
98         return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
99       newResultTypes.push_back(newType);
100     }
101 
102     // See comments in the ForOp pattern for why we clone without regions and
103     // then inline.
104     IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
105     rewriter.inlineRegionBefore(op.thenRegion(), newOp.thenRegion(),
106                                 newOp.thenRegion().end());
107     rewriter.inlineRegionBefore(op.elseRegion(), newOp.elseRegion(),
108                                 newOp.elseRegion().end());
109 
110     // Update the operands and types.
111     newOp->setOperands(operands);
112     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
113       std::get<0>(t).setType(std::get<1>(t));
114     rewriter.replaceOp(op, newOp.getResults());
115     return success();
116   }
117 };
118 } // namespace
119 
120 namespace {
121 // When the result types of a ForOp/IfOp get changed, the operand types of the
122 // corresponding yield op need to be changed. In order to trigger the
123 // appropriate type conversions / materializations, we need a dummy pattern.
124 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
125 public:
126   using OpConversionPattern::OpConversionPattern;
127   LogicalResult
matchAndRewrite(scf::YieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const128   matchAndRewrite(scf::YieldOp op, ArrayRef<Value> operands,
129                   ConversionPatternRewriter &rewriter) const override {
130     rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
131     return success();
132   }
133 };
134 } // namespace
135 
populateSCFStructuralTypeConversionsAndLegality(MLIRContext * context,TypeConverter & typeConverter,OwningRewritePatternList & patterns,ConversionTarget & target)136 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
137     MLIRContext *context, TypeConverter &typeConverter,
138     OwningRewritePatternList &patterns, ConversionTarget &target) {
139   patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
140       typeConverter, context);
141   target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
142     return typeConverter.isLegal(op->getResultTypes());
143   });
144   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
145     // We only have conversions for a subset of ops that use scf.yield
146     // terminators.
147     if (!isa<ForOp, IfOp>(op->getParentOp()))
148       return true;
149     return typeConverter.isLegal(op.getOperandTypes());
150   });
151 }
152