1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/Support/Casting.h"
17 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
18 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Support/LLVM.h"
28
29 #define DEBUG_TYPE "mhlo-control-flow-to-scf"
30
31 namespace mlir {
32 namespace mhlo {
33
34 namespace {
35
/// Convert MHLO While to SCF.
/// Forward declaration (defined below): rewrites `whileOp` in place when it
/// matches a simple counted-loop form; otherwise leaves it unchanged.
void MatchAndRewrite(WhileOp whileOp);
38
39 /// Pass that converts MHLO control flow to SCF.
40 class ControlFlowToScfPass
41 : public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const42 void getDependentDialects(DialectRegistry& registry) const override {
43 registry.insert<scf::SCFDialect>();
44 }
runOnFunction()45 void runOnFunction() override {
46 getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
47 }
48 };
49
50 // TODO(jpienaar): Look into reformulating as a pattern.
MatchAndRewrite(WhileOp whileOp)51 void MatchAndRewrite(WhileOp whileOp) {
52 // Handle pattern:
53 // x = start
54 // step = ...
55 // limit = ...
56 // while (x < limit) { ... x += step; }
57
58 // Only handling multi value while loops at the moment.
59 auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
60 if (!tupleOp) return;
61 auto bodyReturn = whileOp.body()
62 .front()
63 .getTerminator()
64 ->getOperand(0)
65 .getDefiningOp<mhlo::TupleOp>();
66 // Note: due to the shape restrictions on While, if the operand to While is a
67 // tuple, then so is the return type of the body. But the verifier isn't
68 // checking that at the moment, so just bail out here if this doesn't hold.
69 if (!bodyReturn) return;
70
71 Value result = whileOp.cond().front().getTerminator()->getOperand(0);
72 // TODO(jpienaar): Expand to handle more than simple case with LT compare and
73 // constant step.
74 auto cmp = result.getDefiningOp<mhlo::CompareOp>();
75 if (!cmp || cmp.comparison_direction() != "LT") return;
76
77 const int kConstant = -1;
78 auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
79 if (matchPattern(val, m_Constant())) return {val, kConstant};
80 // If it is defined by a tuple, then the tuple has to have been fed in and
81 // the external value is captured.
82 if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
83 if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
84 int index = gte.index();
85 return {tupleOp.getOperand(index), index};
86 }
87 return {nullptr, 0};
88 };
89
90 using ValueIndex = std::pair<Value, int>;
91 ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
92 ValueIndex max = getValueAndIndex(cmp.rhs());
93 if (!loopIndVar.first || !max.first) return;
94 auto add =
95 bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
96 if (!add) return;
97 ValueIndex step = getValueAndIndex(add.rhs());
98 if (step.second != kConstant || !step.first) return;
99
100 // Only handle case where tuple isn't propagated as is for now.
101 // TODO(jpienaar): Remove this when a tuple is also created inside the loop
102 // to propagate.
103 for (auto* use : whileOp.body().front().getArgument(0).getUsers())
104 if (!isa<GetTupleElementOp>(use)) return;
105
106 LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
107 llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
108 << max.second << " step = " << step.second << "\n";
109 llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
110 << max.first << " step = " << step.first << "\n";);
111 OpBuilder b(whileOp);
112 // Inputs to new for loop.
113 llvm::SmallVector<Value, 4> input;
114 input.reserve(tupleOp.getNumOperands());
115 for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
116 input.push_back(r);
117 for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
118 input.push_back(r);
119
120 auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
121 auto getAsIndex = [&](Value val) {
122 auto loc = whileOp.getLoc();
123 return b.create<tensor::ExtractOp>(
124 loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
125 };
126
127 // SCF for uses index type, so converted these.
128 auto forloopIndVar = getAsIndex(loopIndVar.first);
129 auto forMax = getAsIndex(max.first);
130 auto forStep = getAsIndex(step.first);
131 auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
132 forMax, forStep, input);
133 // Transfer the body without the block arguments.
134 forOp.getLoopBody().front().getOperations().splice(
135 forOp.getLoopBody().front().getOperations().end(),
136 whileOp.body().front().getOperations());
137
138 b.setInsertionPointToStart(&forOp.getLoopBody().front());
139 auto loopIndVarElType =
140 loopIndVar.first.getType().cast<ShapedType>().getElementType();
141 Value indVar = b.create<SplatOp>(
142 whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
143 b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
144 forOp.getInductionVar()));
145 // Update all block argument users to the SCF For args.
146 for (auto* use :
147 llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
148 // TODO(jpienaar): Expand here too when we allow using the tuple in the
149 // loop.
150 auto gte = cast<GetTupleElementOp>(use);
151 // If the loop induction var, then refer to the loop induction variable as
152 // this operand is not updated.
153 if (gte.index() == loopIndVar.second) {
154 use->getResult(0).replaceAllUsesWith(indVar);
155 use->erase();
156 continue;
157 }
158 int index = gte.index();
159 // If after the loop induction variable, then decrement as we don't include
160 // the loop induction variable in the for iter operands.
161 if (index > loopIndVar.second) --index;
162 use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
163 use->erase();
164 }
165
166 // Create new yield op without induction var update.
167 SmallVector<Value, 4> newYieldOps;
168 newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
169 for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
170 newYieldOps.push_back(r);
171 for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
172 newYieldOps.push_back(r);
173 // Delete return & tuple op.
174 forOp.getLoopBody().front().back().erase();
175 forOp.getLoopBody().front().back().erase();
176 b.setInsertionPointToEnd(&forOp.getLoopBody().front());
177 b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
178
179 // Recombine output tuple with max value of induction variable.
180 llvm::SmallVector<Value, 4> loopOut;
181 loopOut.reserve(forOp.getNumResults() + 1);
182 for (auto r : forOp.getResults().take_front(loopIndVar.second))
183 loopOut.push_back(r);
184 loopOut.push_back(max.first);
185 for (auto r : forOp.getResults().drop_front(loopIndVar.second))
186 loopOut.push_back(r);
187 b.setInsertionPoint(whileOp);
188 auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
189 whileOp.replaceAllUsesWith(newRes.getOperation());
190 whileOp.erase();
191 }
192
193 } // anonymous namespace
194
createControlFlowToScfPass()195 std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
196 return std::make_unique<ControlFlowToScfPass>();
197 }
198
199 } // namespace mhlo
200 } // namespace mlir
201