/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"

#define DEBUG_TYPE "mhlo-control-flow-to-scf"

namespace mlir {
namespace mhlo {

namespace {

/// Convert MHLO While to SCF.
void MatchAndRewrite(WhileOp whileOp);

/// Pass that converts MHLO control flow to SCF.
class ControlFlowToScfPass
    : public LegalizeControlFlowToScfPassBase<ControlFlowToScfPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<scf::SCFDialect>();
  }
  void runOnFunction() override {
    getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
  }
};

// TODO(jpienaar): Look into reformulating as a pattern.
void MatchAndRewrite(WhileOp whileOp) {
  // TODO(jpienaar): Support multi-operand while op.
  if (whileOp.arg().size() != 1) return;

  // Handle pattern:
  //   x = start
  //   step = ...
  //   limit = ...
  //   while (x < limit) { ... x += step; }
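  //
  // For illustration only, a roughly corresponding mhlo.while (a hand-written
  // sketch; types, attribute spellings, and value names are illustrative and
  // not taken from an actual test case):
  //
  //   %init = "mhlo.tuple"(%start, %limit) : (...) -> tuple<...>
  //   %w = "mhlo.while"(%init) ({
  //   ^bb0(%arg: tuple<...>):  // condition region
  //     %x   = "mhlo.get_tuple_element"(%arg) {index = 0 : i32} : ...
  //     %lim = "mhlo.get_tuple_element"(%arg) {index = 1 : i32} : ...
  //     %lt  = "mhlo.compare"(%x, %lim) {comparison_direction = "LT"} : ...
  //     "mhlo.return"(%lt) : (tensor<i1>) -> ()
  //   }, {
  //   ^bb0(%arg: tuple<...>):  // body region
  //     %x    = "mhlo.get_tuple_element"(%arg) {index = 0 : i32} : ...
  //     %next = mhlo.add %x, %step_cst : ...
  //     %t    = "mhlo.tuple"(%next, ...) : (...) -> tuple<...>
  //     "mhlo.return"(%t) : (tuple<...>) -> ()
  //   }) : (tuple<...>) -> tuple<...>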

  // Only handling multi-value while loops (tuple operand) at the moment.
  // TODO(jpienaar): Support multi-operand while op.
  auto tupleOp = whileOp.getOperand(0).getDefiningOp<TupleOp>();
  if (!tupleOp) return;
  auto bodyReturn = whileOp.body()
                        .front()
                        .getTerminator()
                        ->getOperand(0)
                        .getDefiningOp<mhlo::TupleOp>();
  // Note: due to the shape restrictions on While, if the operand to While is a
  // tuple, then so is the return type of the body. But the verifier isn't
  // checking that at the moment, so just bail out here if this doesn't hold.
  if (!bodyReturn) return;

  Value result = whileOp.cond().front().getTerminator()->getOperand(0);
  // TODO(jpienaar): Expand to handle more than the simple case with LT compare
  // and constant step.
  auto cmp = result.getDefiningOp<mhlo::CompareOp>();
  if (!cmp || cmp.comparison_direction() != "LT") return;

  const int kConstant = -1;
  auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
    if (matchPattern(val, m_Constant())) return {val, kConstant};
    // If it is a get_tuple_element of the block argument, then the tuple was
    // fed in and the corresponding external value is captured.
    if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
      if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
      int index = gte.index();
      return {tupleOp.getOperand(index), index};
    }
    return {nullptr, 0};
  };

  using ValueIndex = std::pair<Value, int>;
  ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
  ValueIndex max = getValueAndIndex(cmp.rhs());
  if (!loopIndVar.first || !max.first) return;
  auto add =
      bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
  if (!add) return;
  ValueIndex step = getValueAndIndex(add.rhs());
  if (step.second != kConstant || !step.first) return;

  // Only handle the case where the tuple isn't propagated as-is for now.
  // TODO(jpienaar): Remove this when a tuple is also created inside the loop
  // to propagate.
  for (auto* use : whileOp.body().front().getArgument(0).getUsers())
    if (!isa<GetTupleElementOp>(use)) return;

  LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
             llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
                          << max.second << " step = " << step.second << "\n";
             llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
                          << max.first << " step = " << step.first << "\n";);
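
  // The rewrite below builds, roughly (a sketch; names are illustrative only):
  //
  //   %lb   = tensor.extract(index_cast(%start))
  //   %ub   = tensor.extract(index_cast(%limit))
  //   %step = tensor.extract(index_cast(%step_cst))
  //   %res:N = scf.for %iv = %lb to %ub step %step
  //            iter_args(<non-indvar tuple operands>) {
  //     %x = splat(index_cast(%iv))  // rank-0 tensor induction variable
  //     ...                          // original body, minus tuple plumbing
  //     scf.yield <non-indvar results>
  //   }
  //   %new = mhlo.tuple(<results before indvar>, %limit, <results after>)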
  OpBuilder b(whileOp);
  // Inputs to new for loop.
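  // The induction variable operand (position loopIndVar.second) is skipped
  // here: it is modeled by the scf.for induction variable rather than being
  // carried as an iter arg.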
  llvm::SmallVector<Value, 4> input;
  input.reserve(tupleOp.getNumOperands());
  for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
    input.push_back(r);
  for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
    input.push_back(r);

  auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
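  // Helper to turn a rank-0 integer tensor into a scalar index value: cast to
  // tensor<index>, then extract the single element.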
  auto getAsIndex = [&](Value val) {
    auto loc = whileOp.getLoc();
    return b.create<tensor::ExtractOp>(
        loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
  };

  // scf.for uses the index type, so convert these.
  auto forloopIndVar = getAsIndex(loopIndVar.first);
  auto forMax = getAsIndex(max.first);
  auto forStep = getAsIndex(step.first);
  auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
                                          forMax, forStep, input);
  // Transfer the body without the block arguments.
  forOp.getLoopBody().front().getOperations().splice(
      forOp.getLoopBody().front().getOperations().end(),
      whileOp.body().front().getOperations());

  b.setInsertionPointToStart(&forOp.getLoopBody().front());
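  // Recreate the induction variable inside the loop body as a rank-0 tensor of
  // the original element type so its existing users keep type-checking.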
  auto loopIndVarElType =
      loopIndVar.first.getType().cast<ShapedType>().getElementType();
  Value indVar = b.create<SplatOp>(
      whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
      b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
                            forOp.getInductionVar()));
  // Update all block argument users to the SCF For args.
  for (auto* use :
       llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
    // TODO(jpienaar): Expand here too when we allow using the tuple in the
    // loop.
    auto gte = cast<GetTupleElementOp>(use);
    // If this is the loop induction variable, replace uses with the recreated
    // induction variable, as this operand is not carried as an iter arg.
    if (gte.index() == loopIndVar.second) {
      use->getResult(0).replaceAllUsesWith(indVar);
      use->erase();
      continue;
    }
    int index = gte.index();
    // If after the loop induction variable, then decrement as we don't include
    // the loop induction variable in the for iter operands.
    if (index > loopIndVar.second) --index;
    use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
    use->erase();
  }

  // Create new yield op without induction var update.
  SmallVector<Value, 4> newYieldOps;
  newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
  for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
    newYieldOps.push_back(r);
  for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
    newYieldOps.push_back(r);
  // Delete the body's return terminator and then the tuple op it returned
  // (now the last op in the block).
  forOp.getLoopBody().front().back().erase();
  forOp.getLoopBody().front().back().erase();
  b.setInsertionPointToEnd(&forOp.getLoopBody().front());
  b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);

  // Recombine output tuple with max value of induction variable.
  llvm::SmallVector<Value, 4> loopOut;
  loopOut.reserve(forOp.getNumResults() + 1);
  for (auto r : forOp.getResults().take_front(loopIndVar.second))
    loopOut.push_back(r);
  loopOut.push_back(max.first);
  for (auto r : forOp.getResults().drop_front(loopIndVar.second))
    loopOut.push_back(r);
  b.setInsertionPoint(whileOp);
  auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
  whileOp.replaceAllUsesWith(newRes.getOperation());
  whileOp.erase();
}

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
  return std::make_unique<ControlFlowToScfPass>();
}
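
// A minimal sketch of driving this pass programmatically (an assumption about
// typical usage, not the only way it is registered in downstream pipelines):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createControlFlowToScfPass());
//   if (mlir::failed(pm.run(module))) {
//     // handle failure
//   }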

}  // namespace mhlo
}  // namespace mlir