1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering MHLO dialect to SCF dialect.
17 #include <utility>
18
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/StringSwitch.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
29 #include "mlir/IR/Block.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/Diagnostics.h"
34 #include "mlir/IR/ImplicitLocOpBuilder.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/TypeRange.h"
37 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Pass/PassRegistry.h"
40 #include "mlir/Support/LLVM.h"
41 #include "mlir/Support/LogicalResult.h"
42 #include "mlir/Transforms/DialectConversion.h"
43
44 namespace mlir {
45 namespace mhlo {
46 namespace {
47
48 // All transformations in this file take mhlo blocks which end with
49 // mhlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an
50 // entire block with the only change being return -> yield.
inlineMhloRegionIntoSCFRegion(PatternRewriter & rewriter,Region & mhlo,Region & scf)51 void inlineMhloRegionIntoSCFRegion(PatternRewriter& rewriter, Region& mhlo,
52 Region& scf) {
53 // Remove an existing block, then move the region over.
54 if (!scf.empty()) rewriter.eraseBlock(&scf.back());
55 rewriter.inlineRegionBefore(mhlo, scf, scf.end());
56 // Fix up the terminator.
57 PatternRewriter::InsertionGuard guard(rewriter);
58 rewriter.setInsertionPointToEnd(&scf.back());
59 auto* terminator = scf.back().getTerminator();
60 rewriter.replaceOpWithNewOp<scf::YieldOp>(terminator,
61 terminator->getOperands());
62 }
63
64 // mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor
65 // or a 1 element tensor. To handle this, collapse shape before extracting the
66 // scalar value when necessary.
extractTensorValue(OpBuilder & b,Value tensor)67 Value extractTensorValue(OpBuilder& b, Value tensor) {
68 auto loc = tensor.getLoc();
69 if (tensor.getType().cast<TensorType>().hasRank() &&
70 tensor.getType().cast<TensorType>().getRank() != 0) {
71 tensor = b.create<tensor::CollapseShapeOp>(
72 loc, tensor, SmallVector<ReassociationIndices>());
73 }
74 return b.create<tensor::ExtractOp>(loc, tensor, ValueRange());
75 }
76
77 // Create a memref descriptor given a pointer and memref type information.
78 struct WhileOpPattern : public OpConversionPattern<mhlo::WhileOp> {
79 using OpConversionPattern<WhileOp>::OpConversionPattern;
80
matchAndRewritemlir::mhlo::__anon0611b08e0111::WhileOpPattern81 LogicalResult matchAndRewrite(
82 mhlo::WhileOp op, OpAdaptor adaptor,
83 ConversionPatternRewriter& rewriter) const override {
84 auto loc = op.getLoc();
85
86 auto newWhileOp = rewriter.create<scf::WhileOp>(loc, op.getResultTypes(),
87 adaptor.getOperands());
88
89 // Inline while condition. The block is the same, except the boolean result
90 // needs to be extracted and used with an scf.condition.
91 rewriter.inlineRegionBefore(op.cond(), newWhileOp.getBefore(),
92 newWhileOp.getBefore().end());
93 auto conditionReturn =
94 cast<mhlo::ReturnOp>(newWhileOp.getBefore().front().getTerminator());
95 rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front());
96 Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0));
97 rewriter.replaceOpWithNewOp<scf::ConditionOp>(
98 conditionReturn, i1, newWhileOp.getBeforeArguments());
99
100 // Inline while body, and only replace the mhlo.return with an scf.yield.
101 inlineMhloRegionIntoSCFRegion(rewriter, op.body(), newWhileOp.getAfter());
102
103 rewriter.replaceOp(op, newWhileOp.getResults());
104 return success();
105 }
106 };
107
108 // Create a memref descriptor given a pointer and memref type information.
109 struct IfOpPattern : public OpConversionPattern<mhlo::IfOp> {
110 using OpConversionPattern<IfOp>::OpConversionPattern;
111
matchAndRewritemlir::mhlo::__anon0611b08e0111::IfOpPattern112 LogicalResult matchAndRewrite(
113 mhlo::IfOp op, OpAdaptor adaptor,
114 ConversionPatternRewriter& rewriter) const override {
115 auto scfIf =
116 rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
117 extractTensorValue(rewriter, adaptor.pred()),
118 /*withElseRegion=*/true);
119 inlineMhloRegionIntoSCFRegion(rewriter, op.true_branch(),
120 scfIf.getThenRegion());
121 inlineMhloRegionIntoSCFRegion(rewriter, op.false_branch(),
122 scfIf.getElseRegion());
123 rewriter.replaceOp(op, scfIf.getResults());
124 return success();
125 }
126 };
127
128 // Create a memref descriptor given a pointer and memref type information.
129 struct CaseOpPattern : public OpConversionPattern<mhlo::CaseOp> {
130 using OpConversionPattern<CaseOp>::OpConversionPattern;
131
132 // Recursively create if/else ops to handle each possible value in a case op.
createNestedCasesmlir::mhlo::__anon0611b08e0111::CaseOpPattern133 scf::IfOp createNestedCases(int currentIdx, CaseOp op, OpAdaptor adaptor,
134 PatternRewriter& outerBuilder) const {
135 Location loc = op.getLoc();
136 Value idxValue = adaptor.index();
137 auto finalIdx = op.branches().size() - 2;
138
139 // Determine if the current index matches the case index.
140 auto scalarType = idxValue.getType();
141 auto constAttr = DenseElementsAttr::get(
142 scalarType,
143 {outerBuilder.getI32IntegerAttr(currentIdx).cast<mlir::Attribute>()});
144 Value currentIdxVal = outerBuilder.create<mhlo::ConstantOp>(
145 loc, idxValue.getType(), constAttr);
146
147 auto scfIf = outerBuilder.create<scf::IfOp>(
148 loc, op.getResultTypes(),
149 extractTensorValue(outerBuilder, outerBuilder.create<mhlo::CompareOp>(
150 loc, idxValue, currentIdxVal,
151 ComparisonDirection::EQ)),
152 /*withElseRegion=*/true);
153 inlineMhloRegionIntoSCFRegion(outerBuilder, op.branches()[currentIdx],
154 scfIf.getThenRegion());
155 int nextIdx = currentIdx + 1;
156 // Don't recurse for the final default block.
157 if (currentIdx == static_cast<int64_t>(finalIdx)) {
158 inlineMhloRegionIntoSCFRegion(outerBuilder, op.branches()[nextIdx],
159 scfIf.getElseRegion());
160 } else {
161 PatternRewriter::InsertionGuard guard(outerBuilder);
162 outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back());
163 auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder);
164 outerBuilder.create<scf::YieldOp>(op.getLoc(), innerIf.getResults());
165 }
166 return scfIf;
167 }
168
matchAndRewritemlir::mhlo::__anon0611b08e0111::CaseOpPattern169 LogicalResult matchAndRewrite(
170 mhlo::CaseOp op, OpAdaptor adaptor,
171 ConversionPatternRewriter& rewriter) const override {
172 // Inline the op if there is only a default block.
173 if (op.branches().size() == 1) {
174 Block& block = op.branches().front().front();
175 auto results = block.getTerminator()->getOperands();
176 // Remove the mhlo.return terminator, then inline the block.
177 rewriter.eraseOp(block.getTerminator());
178 rewriter.mergeBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(),
179 /*argValues=*/{});
180 rewriter.replaceOp(op, results);
181 return success();
182 }
183
184 // Begin recursion with case 0.
185 rewriter.replaceOp(
186 op, createNestedCases(0, op, adaptor, rewriter).getResults());
187 return success();
188 }
189 };
190
191 struct LegalizeControlFlowPass
192 : public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
193 // Perform the lowering to MLIR control flow.
runOnOperationmlir::mhlo::__anon0611b08e0111::LegalizeControlFlowPass194 void runOnOperation() override {
195 func::FuncOp f = getOperation();
196 MLIRContext* ctx = f.getContext();
197
198 RewritePatternSet patterns(&getContext());
199 patterns.add<WhileOpPattern, IfOpPattern, CaseOpPattern>(&getContext());
200
201 mlir::ConversionTarget target(*ctx);
202 target.markUnknownOpDynamicallyLegal([](Operation*) { return true; });
203 target.addIllegalOp<mhlo::IfOp, mhlo::WhileOp, mhlo::CaseOp>();
204
205 if (failed(applyPartialConversion(f, target, std::move(patterns)))) {
206 signalPassFailure();
207 }
208 }
209 };
210
211 } // namespace
212 } // namespace mhlo
213 } // namespace mlir
214
215 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createLegalizeControlFlowPass()216 mlir::mhlo::createLegalizeControlFlowPass() {
217 return std::make_unique<LegalizeControlFlowPass>();
218 }
219