1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering MHLO dialect to Standard dialect.
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"
25 #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
26 #include "mlir/IR/Block.h"
27 #include "mlir/IR/BlockAndValueMapping.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Pass/PassRegistry.h"
35 #include "mlir/Support/LogicalResult.h"
36
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40 struct LegalizeControlFlowPass
41 : public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
42 // Perform the lowering to MLIR control flow.
43 void runOnFunction() override;
44 };
45
46 // Replaces terminators for the newly created blocks from a targe region.
47 // These terminators are replaced with branch operations to a target block.
ReplaceTerminators(Region * region,Block * target_block,Location loc,const BlockAndValueMapping & mapper,OpBuilder * builder)48 LogicalResult ReplaceTerminators(Region* region, Block* target_block,
49 Location loc,
50 const BlockAndValueMapping& mapper,
51 OpBuilder* builder) {
52 for (auto& old_block : region->getBlocks()) {
53 Block* block = mapper.lookup(&old_block);
54 auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
55 if (!return_op) continue;
56 builder->setInsertionPointToEnd(block);
57 builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
58 return_op.erase();
59 }
60
61 return success();
62 }
63
LowerIfOp(mlir::mhlo::IfOp if_op)64 LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
65 Operation* op_inst = if_op.getOperation();
66 mlir::OpBuilder builder(if_op);
67 auto orig_block = op_inst->getBlock();
68 auto* tail_block = orig_block->splitBlock(op_inst);
69 auto loc = if_op.getLoc();
70
71 // Duplicate the true and false regions in the block between the sections
72 // before and after the conditional.
73 BlockAndValueMapping mapper;
74 if_op.true_branch().cloneInto(orig_block->getParent(),
75 Region::iterator(tail_block), mapper);
76 if_op.false_branch().cloneInto(orig_block->getParent(),
77 Region::iterator(tail_block), mapper);
78
79 // Determine the blocks for the start of the true and false regions.
80 Block* true_block = mapper.lookup(&if_op.true_branch().front());
81 Block* false_block = mapper.lookup(&if_op.false_branch().front());
82
83 // Perform the conditional branch into the true/false cases.
84 builder.setInsertionPointToEnd(orig_block);
85
86 // Extract the predicate for checking branching, then branch to the true and
87 // false regions appropriately.
88 auto cond_value = builder.create<mlir::tensor::ExtractOp>(loc, if_op.pred());
89 builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
90 if_op.true_arg(), false_block,
91 if_op.false_arg());
92
93 // Replace the true case's return operations with a branch to the tail of
94 // the condition.
95 if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
96 &builder)))
97 return failure();
98 if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
99 &builder)))
100 return failure();
101
102 tail_block->addArguments(if_op.getResult().getType());
103 if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
104
105 op_inst->erase();
106 return success();
107 }
108
LowerWhileOp(mlir::mhlo::WhileOp while_op)109 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
110 // TODO(jpienaar): Support multi-operand while op.
111 if (while_op.arg().size() != 1) return failure();
112
113 // Converts a MHLO while loop into control flow. This generates a set of MLIR
114 // blocks and branches, along with inlining the regions provided by the MHLO
115 // while loop. The structure should be similar to below:
116 //
117 // <prior operations>
118 // %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
119 // <post operations>
120 auto* op_inst = while_op.getOperation();
121 mlir::OpBuilder builder(while_op);
122 auto loc = while_op.getLoc();
123
124 // Break the block into four sections:
125 // orig_block - operations before the while and the branch into looping check.
126 // tail_block - operations after the while loop completes.
127 // cond_block - check the looping condition, then conditionally branch into
128 // the loop or, if condition is false, jump to the tail branch.
129 // body_block - inlined loop body, then jump back to the condition block.
130 auto* orig_block = op_inst->getBlock();
131 auto* tail_block = orig_block->splitBlock(op_inst);
132
133 BlockAndValueMapping mapper;
134 while_op.cond().cloneInto(orig_block->getParent(),
135 Region::iterator(tail_block), mapper);
136 while_op.body().cloneInto(orig_block->getParent(),
137 Region::iterator(tail_block), mapper);
138
139 // Lookup the entry blocks for both condition and body.
140 auto* cond_block = mapper.lookup(&while_op.cond().front());
141 auto* body_block = mapper.lookup(&while_op.body().front());
142
143 // Setup the end of the original block:
144 // <prior operations>
145 // br ^cond(%arg0) // Jumps to the condition statement.
146 builder.setInsertionPointToEnd(orig_block);
147 // TODO(jpienaar): Support multi-operand while op.
148 builder.create<mlir::BranchOp>(loc, cond_block, while_op.arg()[0]);
149
150 // Updates the inlined condition blocks by replacing the return op with an
151 // tensor.extract and conditional branch. This changes the block below:
152 // ^cond(%0):
153 // <inlined conditional region>
154 // "mhlo".return(%1)
155 //
156 // Into:
157 // ^cond(%0):
158 // <inlined conditional region>
159 // %2 = tensor.extract %1[] : tensor<i1> // Extract the condition value.
160 // cond_br %2, ^body(%0), ^tail(%0) // Branch.
161 builder.setInsertionPointToStart(cond_block);
162
163 // Replace the mhlo::ReturnOp with a branch back to the condition block.
164 // This is required as the mhlo::ReturnOp is used to mark the end of a
165 // block for regions nested inside of a operations (MLIR ReturnOp cannot be
166 // nested within an non-function region).
167 for (auto& block : while_op.cond()) {
168 auto new_block = mapper.lookup(&block);
169
170 auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
171 if (!return_op) continue;
172 builder.setInsertionPointToEnd(new_block);
173
174 auto return_value = return_op.getOperand(0);
175 auto cond_value =
176 builder.create<mlir::tensor::ExtractOp>(loc, return_value);
177
178 // Get the body block arguments.
179 llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
180 cond_block->args_end());
181 builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
182 successor_args, tail_block,
183 successor_args);
184 return_op.erase();
185 }
186
187 // Updates the body blocks by replace the return op with an branch to the
188 // conditional block. This changes the block below:
189 // ^body(%0):
190 // <inlined body block>
191 // "mhlo".return(%1)
192 //
193 // Into:
194 // ^body(%0):
195 // <inlined body block>
196 // br ^cond(%0) // Branch.
197 for (auto& block : while_op.body()) {
198 auto new_block = mapper.lookup(&block);
199 auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
200 if (!return_op) continue;
201 builder.setInsertionPointToEnd(new_block);
202 builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
203 return_op.erase();
204 }
205
206 // Erase the original while loop.
207 // TODO(jpienaar): Support multi-operand while op.
208 tail_block->addArgument(while_op.arg().getType()[0]);
209 while_op.getResult(0).replaceAllUsesWith(tail_block->getArgument(0));
210 op_inst->erase();
211
212 return success();
213 }
214
runOnFunction()215 void LegalizeControlFlowPass::runOnFunction() {
216 auto func = getFunction();
217 llvm::SmallVector<IfOp, 4> if_ops;
218 func.walk([&](IfOp op) { if_ops.push_back(op); });
219
220 for (auto& op : if_ops) {
221 if (failed(LowerIfOp(op))) return signalPassFailure();
222 }
223
224 llvm::SmallVector<WhileOp, 4> while_ops;
225 func.walk([&](WhileOp op) { while_ops.push_back(op); });
226
227 for (auto& op : while_ops) {
228 if (failed(LowerWhileOp(op))) return signalPassFailure();
229 }
230 }
231 } // namespace
232 } // namespace mhlo
233 } // namespace mlir
234
235 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeControlFlowPass()236 mlir::mhlo::createLegalizeControlFlowPass() {
237 return std::make_unique<LegalizeControlFlowPass>();
238 }
239