/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"

#define DEBUG_TYPE "mhlo-control-flow-to-scf"

namespace mlir {
namespace mhlo {

namespace {

/// Convert MHLO While to SCF.
void MatchAndRewrite(WhileOp whileOp);

/// Pass that converts MHLO control flow to SCF.
class ControlFlowToScfPass
    : public LegalizeControlFlowToScfPassBase<ControlFlowToScfPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<scf::SCFDialect>();
  }
  void runOnFunction() override {
    getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
  }
};

// TODO(jpienaar): Look into reformulating as a pattern.
void MatchAndRewrite(WhileOp whileOp) {
  // TODO(jpienaar): Support multi-operand while op.
  if (whileOp.arg().size() != 1) return;

  // Handle pattern:
  //   x = start
  //   step = ...
  //   limit = ...
  //   while (x < limit) { ... x += step; }
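  //
  // As an illustrative sketch only (names, types, and exact attribute
  // spelling here are assumptions, not taken from a real test case), the
  // matched IR looks roughly like:
  //
  //   %t = "mhlo.tuple"(%start, %limit) : (...) -> tuple<...>
  //   %w = "mhlo.while"(%t) ( {          // cond region
  //     ^bb0(%arg: tuple<...>):
  //       %x = "mhlo.get_tuple_element"(%arg) {index = 0 : i32} : ...
  //       %l = "mhlo.get_tuple_element"(%arg) {index = 1 : i32} : ...
  //       %p = "mhlo.compare"(%x, %l) {comparison_direction = "LT"} : ...
  //       "mhlo.return"(%p) : ...
  //   }, {                               // body region
  //     ^bb0(%arg: tuple<...>):
  //       ... %x1 = mhlo.add(%x0, %step) ...
  //       "mhlo.return"(%res_tuple) : ...
  //   }) : ...
  //
  // and is rewritten into an scf.for over index-typed bounds whose results
  // are repacked into an mhlo.tuple so existing uses keep seeing a tuple.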

  // Only tuple-typed (i.e. multi-value) while loops are handled at the
  // moment.
  // TODO(jpienaar): Support multi-operand while op.
  auto tupleOp = whileOp.getOperand(0).getDefiningOp<TupleOp>();
  if (!tupleOp) return;
  auto bodyReturn = whileOp.body()
                        .front()
                        .getTerminator()
                        ->getOperand(0)
                        .getDefiningOp<mhlo::TupleOp>();
  // Note: due to the shape restrictions on While, if the operand to While is a
  // tuple, then so is the return type of the body. But the verifier isn't
  // checking that at the moment, so just bail out here if this doesn't hold.
  if (!bodyReturn) return;

  Value result = whileOp.cond().front().getTerminator()->getOperand(0);
  // TODO(jpienaar): Expand to handle more than the simple case with an LT
  // compare and a constant step.
  auto cmp = result.getDefiningOp<mhlo::CompareOp>();
  if (!cmp || cmp.comparison_direction() != "LT") return;

  const int kConstant = -1;
  auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
    if (matchPattern(val, m_Constant())) return {val, kConstant};
    // If the value is a get_tuple_element of the loop's tuple argument, the
    // corresponding value was fed in from outside the loop, so recover the
    // captured external value from the original tuple operand.
    if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
      if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
      int index = gte.index();
      return {tupleOp.getOperand(index), index};
    }
    return {nullptr, 0};
  };
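  // A null Value in the first slot signals failure; kConstant (-1) in the
  // second slot marks a constant rather than a loop-carried tuple element.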

  using ValueIndex = std::pair<Value, int>;
  ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
  ValueIndex max = getValueAndIndex(cmp.rhs());
  if (!loopIndVar.first || !max.first) return;
  auto add =
      bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
  if (!add) return;
  ValueIndex step = getValueAndIndex(add.rhs());
  if (step.second != kConstant || !step.first) return;

  // Only handle the case where the tuple isn't propagated as-is for now.
  // TODO(jpienaar): Remove this when a tuple is also created inside the loop
  // to propagate.
  for (auto* use : whileOp.body().front().getArgument(0).getUsers())
    if (!isa<GetTupleElementOp>(use)) return;

  LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
             llvm::dbgs() << "  loopIndVar = " << loopIndVar.second << " max = "
                          << max.second << " step = " << step.second << "\n";
             llvm::dbgs() << "  loopIndVar = " << loopIndVar.first << " max = "
                          << max.first << " step = " << step.first << "\n";);
  OpBuilder b(whileOp);
  // Inputs to the new for loop: all tuple operands except the induction
  // variable.
  llvm::SmallVector<Value, 4> input;
  input.reserve(tupleOp.getNumOperands());
  for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
    input.push_back(r);
  for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
    input.push_back(r);

  auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
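  // Turn a rank-0 integer tensor into a scalar of index type: index_cast the
  // 0-d tensor to tensor<index>, then tensor.extract the single element.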
  auto getAsIndex = [&](Value val) {
    auto loc = whileOp.getLoc();
    return b.create<tensor::ExtractOp>(
        loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
  };

  // scf.for uses the index type for its bounds, so convert these.
  auto forloopIndVar = getAsIndex(loopIndVar.first);
  auto forMax = getAsIndex(max.first);
  auto forStep = getAsIndex(step.first);
  auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
                                          forMax, forStep, input);
  // Transfer the body without the block arguments.
  forOp.getLoopBody().front().getOperations().splice(
      forOp.getLoopBody().front().getOperations().end(),
      whileOp.body().front().getOperations());
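  // Note: splice moves the operations in place rather than cloning them; the
  // while body's block arguments (and the moved mhlo.return/mhlo.tuple
  // terminator pair) are rewired and cleaned up below.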

  b.setInsertionPointToStart(&forOp.getLoopBody().front());
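  // Rebuild the induction variable in the form the loop body expects: cast
  // the scalar index induction variable back to the original element type and
  // splat it into a rank-0 tensor.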
  auto loopIndVarElType =
      loopIndVar.first.getType().cast<ShapedType>().getElementType();
  Value indVar = b.create<SplatOp>(
      whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
      b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
                            forOp.getInductionVar()));
  // Update all block argument users to the corresponding scf.for values.
  for (auto* use :
       llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
    // TODO(jpienaar): Expand here too when we allow using the tuple in the
    // loop.
    auto gte = cast<GetTupleElementOp>(use);
    // If this is the loop induction variable, refer to the reconstructed
    // induction variable, as this operand is not updated via the iter args.
    if (gte.index() == loopIndVar.second) {
      use->getResult(0).replaceAllUsesWith(indVar);
      use->erase();
      continue;
    }
    int index = gte.index();
    // If the element comes after the loop induction variable, decrement the
    // index, since the induction variable is not included in the for's iter
    // operands.
    if (index > loopIndVar.second) --index;
    use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
    use->erase();
  }

  // Create the new yield op without the induction variable update.
  SmallVector<Value, 4> newYieldOps;
  newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
  for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
    newYieldOps.push_back(r);
  for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
    newYieldOps.push_back(r);
  // Delete the return & tuple op.
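  // The spliced mhlo.return is now the last op in the for body, and the
  // mhlo.tuple it returned is expected to sit immediately before it, so two
  // back() erases remove both (back() shifts after the first erase).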
  forOp.getLoopBody().front().back().erase();
  forOp.getLoopBody().front().back().erase();
  b.setInsertionPointToEnd(&forOp.getLoopBody().front());
  b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);

  // Recombine output tuple with max value of induction variable.
  llvm::SmallVector<Value, 4> loopOut;
  loopOut.reserve(forOp.getNumResults() + 1);
  for (auto r : forOp.getResults().take_front(loopIndVar.second))
    loopOut.push_back(r);
  loopOut.push_back(max.first);
  for (auto r : forOp.getResults().drop_front(loopIndVar.second))
    loopOut.push_back(r);
  b.setInsertionPoint(whileOp);
  auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
  whileOp.replaceAllUsesWith(newRes.getOperation());
  whileOp.erase();
}

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
  return std::make_unique<ControlFlowToScfPass>();
}
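
// A minimal usage sketch (assumed driver code, not part of this pass): run
// the pass over every function in a module via a nested pass manager.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createControlFlowToScfPass());
//   if (mlir::failed(pm.run(module))) {
//     // handle legalization failure
//   }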

}  // namespace mhlo
}  // namespace mlir