• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering MHLO dialect to Standard dialect.
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"
25 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
26 #include "mlir/IR/Block.h"
27 #include "mlir/IR/BlockAndValueMapping.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Pass/PassRegistry.h"
35 #include "mlir/Support/LogicalResult.h"
36 
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40 struct LegalizeControlFlowPass
41     : public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
42   // Perform the lowering to MLIR control flow.
43   void runOnFunction() override;
44 };
45 
46 // Replaces terminators for the newly created blocks from a targe region.
47 // These terminators are replaced with branch operations to a target block.
ReplaceTerminators(Region * region,Block * target_block,Location loc,const BlockAndValueMapping & mapper,OpBuilder * builder)48 LogicalResult ReplaceTerminators(Region* region, Block* target_block,
49                                  Location loc,
50                                  const BlockAndValueMapping& mapper,
51                                  OpBuilder* builder) {
52   for (auto& old_block : region->getBlocks()) {
53     Block* block = mapper.lookup(&old_block);
54     auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
55     if (!return_op) continue;
56     builder->setInsertionPointToEnd(block);
57     builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
58     return_op.erase();
59   }
60 
61   return success();
62 }
63 
LowerIfOp(mlir::mhlo::IfOp if_op)64 LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
65   Operation* op_inst = if_op.getOperation();
66   mlir::OpBuilder builder(if_op);
67   auto orig_block = op_inst->getBlock();
68   auto* tail_block = orig_block->splitBlock(op_inst);
69   auto loc = if_op.getLoc();
70 
71   // Duplicate the true and false regions in the block between the sections
72   // before and after the conditional.
73   BlockAndValueMapping mapper;
74   if_op.true_branch().cloneInto(orig_block->getParent(),
75                                 Region::iterator(tail_block), mapper);
76   if_op.false_branch().cloneInto(orig_block->getParent(),
77                                  Region::iterator(tail_block), mapper);
78 
79   // Determine the blocks for the start of the true and false regions.
80   Block* true_block = mapper.lookup(&if_op.true_branch().front());
81   Block* false_block = mapper.lookup(&if_op.false_branch().front());
82 
83   // Perform the conditional branch into the true/false cases.
84   builder.setInsertionPointToEnd(orig_block);
85 
86   // Extract the predicate for checking branching, then branch to the true and
87   // false regions appropriately.
88   auto cond_value = builder.create<mlir::tensor::ExtractOp>(loc, if_op.pred());
89   builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
90                                      if_op.true_arg(), false_block,
91                                      if_op.false_arg());
92 
93   // Replace the true case's return operations with a branch to the tail of
94   // the condition.
95   if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
96                                 &builder)))
97     return failure();
98   if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
99                                 &builder)))
100     return failure();
101 
102   tail_block->addArguments(if_op.getResult().getType());
103   if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
104 
105   op_inst->erase();
106   return success();
107 }
108 
LowerWhileOp(mlir::mhlo::WhileOp while_op)109 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
110   // TODO(jpienaar): Support multi-operand while op.
111   if (while_op.arg().size() != 1) return failure();
112 
113   // Converts a MHLO while loop into control flow. This generates a set of MLIR
114   // blocks and branches, along with inlining the regions provided by the MHLO
115   // while loop. The structure should be similar to below:
116   //
117   //   <prior operations>
118   //   %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
119   //   <post operations>
120   auto* op_inst = while_op.getOperation();
121   mlir::OpBuilder builder(while_op);
122   auto loc = while_op.getLoc();
123 
124   // Break the block into four sections:
125   // orig_block - operations before the while and the branch into looping check.
126   // tail_block - operations after the while loop completes.
127   // cond_block - check the looping condition, then conditionally branch into
128   //              the loop or, if condition is false, jump to the tail branch.
129   // body_block - inlined loop body, then jump back to the condition block.
130   auto* orig_block = op_inst->getBlock();
131   auto* tail_block = orig_block->splitBlock(op_inst);
132 
133   BlockAndValueMapping mapper;
134   while_op.cond().cloneInto(orig_block->getParent(),
135                             Region::iterator(tail_block), mapper);
136   while_op.body().cloneInto(orig_block->getParent(),
137                             Region::iterator(tail_block), mapper);
138 
139   // Lookup the entry blocks for both condition and body.
140   auto* cond_block = mapper.lookup(&while_op.cond().front());
141   auto* body_block = mapper.lookup(&while_op.body().front());
142 
143   // Setup the end of the original block:
144   //     <prior operations>
145   //     br ^cond(%arg0) // Jumps to the condition statement.
146   builder.setInsertionPointToEnd(orig_block);
147   // TODO(jpienaar): Support multi-operand while op.
148   builder.create<mlir::BranchOp>(loc, cond_block, while_op.arg()[0]);
149 
150   // Updates the inlined condition blocks by replacing the return op with an
151   // tensor.extract and conditional branch. This changes the block below:
152   //   ^cond(%0):
153   //     <inlined conditional region>
154   //    "mhlo".return(%1)
155   //
156   //  Into:
157   //   ^cond(%0):
158   //     <inlined conditional region>
159   //     %2 = tensor.extract %1[] : tensor<i1> // Extract the condition value.
160   //     cond_br %2, ^body(%0), ^tail(%0) // Branch.
161   builder.setInsertionPointToStart(cond_block);
162 
163   // Replace the mhlo::ReturnOp with a branch back to the condition block.
164   // This is required as the mhlo::ReturnOp is used to mark the end of a
165   // block for regions nested inside of a operations (MLIR ReturnOp cannot be
166   // nested within an non-function region).
167   for (auto& block : while_op.cond()) {
168     auto new_block = mapper.lookup(&block);
169 
170     auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
171     if (!return_op) continue;
172     builder.setInsertionPointToEnd(new_block);
173 
174     auto return_value = return_op.getOperand(0);
175     auto cond_value =
176         builder.create<mlir::tensor::ExtractOp>(loc, return_value);
177 
178     // Get the body block arguments.
179     llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
180                                                cond_block->args_end());
181     builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
182                                        successor_args, tail_block,
183                                        successor_args);
184     return_op.erase();
185   }
186 
187   // Updates the body blocks by replace the return op with an branch to the
188   // conditional block. This changes the block below:
189   //   ^body(%0):
190   //     <inlined body block>
191   //    "mhlo".return(%1)
192   //
193   //  Into:
194   //   ^body(%0):
195   //     <inlined body block>
196   //     br ^cond(%0) // Branch.
197   for (auto& block : while_op.body()) {
198     auto new_block = mapper.lookup(&block);
199     auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
200     if (!return_op) continue;
201     builder.setInsertionPointToEnd(new_block);
202     builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
203     return_op.erase();
204   }
205 
206   // Erase the original while loop.
207   // TODO(jpienaar): Support multi-operand while op.
208   tail_block->addArgument(while_op.arg().getType()[0]);
209   while_op.getResult(0).replaceAllUsesWith(tail_block->getArgument(0));
210   op_inst->erase();
211 
212   return success();
213 }
214 
runOnFunction()215 void LegalizeControlFlowPass::runOnFunction() {
216   auto func = getFunction();
217   llvm::SmallVector<IfOp, 4> if_ops;
218   func.walk([&](IfOp op) { if_ops.push_back(op); });
219 
220   for (auto& op : if_ops) {
221     if (failed(LowerIfOp(op))) return signalPassFailure();
222   }
223 
224   llvm::SmallVector<WhileOp, 4> while_ops;
225   func.walk([&](WhileOp op) { while_ops.push_back(op); });
226 
227   for (auto& op : while_ops) {
228     if (failed(LowerWhileOp(op))) return signalPassFailure();
229   }
230 }
231 }  // namespace
232 }  // namespace mhlo
233 }  // namespace mlir
234 
235 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeControlFlowPass()236 mlir::mhlo::createLegalizeControlFlowPass() {
237   return std::make_unique<LegalizeControlFlowPass>();
238 }
239