/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"

29 #define DEBUG_TYPE "mhlo-control-flow-to-scf"
30 
31 namespace mlir {
32 namespace mhlo {
33 
34 namespace {
35 
36 /// Convert MHLO While to SCF.
37 void MatchAndRewrite(WhileOp whileOp);
38 
39 /// Pass that converts MHLO control flow to SCF.
40 class ControlFlowToScfPass
41     : public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const42   void getDependentDialects(DialectRegistry& registry) const override {
43     registry.insert<scf::SCFDialect>();
44   }
runOnFunction()45   void runOnFunction() override {
46     getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
47   }
48 };
49 
50 // TODO(jpienaar): Look into reformulating as a pattern.
MatchAndRewrite(WhileOp whileOp)51 void MatchAndRewrite(WhileOp whileOp) {
52   // Handle pattern:
53   //   x = start
54   //   step = ...
55   //   limit = ...
56   //   while (x < limit) { ... x += step; }
57 
58   // Only handling multi value while loops at the moment.
59   auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
60   if (!tupleOp) return;
61   auto bodyReturn = whileOp.body()
62                         .front()
63                         .getTerminator()
64                         ->getOperand(0)
65                         .getDefiningOp<mhlo::TupleOp>();
66   // Note: due to the shape restrictions on While, if the operand to While is a
67   // tuple, then so is the return type of the body. But the verifier isn't
68   // checking that at the moment, so just bail out here if this doesn't hold.
69   if (!bodyReturn) return;
70 
71   Value result = whileOp.cond().front().getTerminator()->getOperand(0);
72   // TODO(jpienaar): Expand to handle more than simple case with LT compare and
73   // constant step.
74   auto cmp = result.getDefiningOp<mhlo::CompareOp>();
75   if (!cmp || cmp.comparison_direction() != "LT") return;
76 
77   const int kConstant = -1;
78   auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
79     if (matchPattern(val, m_Constant())) return {val, kConstant};
80     // If it is defined by a tuple, then the tuple has to have been fed in and
81     // the external value is captured.
82     if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
83       if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
84       int index = gte.index();
85       return {tupleOp.getOperand(index), index};
86     }
87     return {nullptr, 0};
88   };
89 
90   using ValueIndex = std::pair<Value, int>;
91   ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
92   ValueIndex max = getValueAndIndex(cmp.rhs());
93   if (!loopIndVar.first || !max.first) return;
94   auto add =
95       bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
96   if (!add) return;
97   ValueIndex step = getValueAndIndex(add.rhs());
98   if (step.second != kConstant || !step.first) return;
99 
100   // Only handle case where tuple isn't propagated as is for now.
101   // TODO(jpienaar): Remove this when a tuple is also created inside the loop
102   // to propagate.
103   for (auto* use : whileOp.body().front().getArgument(0).getUsers())
104     if (!isa<GetTupleElementOp>(use)) return;
105 
106   LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
107              llvm::dbgs() << "  loopIndVar = " << loopIndVar.second << " max = "
108                           << max.second << " step = " << step.second << "\n";
109              llvm::dbgs() << "  loopIndVar = " << loopIndVar.first << " max = "
110                           << max.first << " step = " << step.first << "\n";);
111   OpBuilder b(whileOp);
112   // Inputs to new for loop.
113   llvm::SmallVector<Value, 4> input;
114   input.reserve(tupleOp.getNumOperands());
115   for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
116     input.push_back(r);
117   for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
118     input.push_back(r);
119 
120   auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
121   auto getAsIndex = [&](Value val) {
122     auto loc = whileOp.getLoc();
123     return b.create<tensor::ExtractOp>(
124         loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
125   };
126 
127   // SCF for uses index type, so converted these.
128   auto forloopIndVar = getAsIndex(loopIndVar.first);
129   auto forMax = getAsIndex(max.first);
130   auto forStep = getAsIndex(step.first);
131   auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
132                                           forMax, forStep, input);
133   // Transfer the body without the block arguments.
134   forOp.getLoopBody().front().getOperations().splice(
135       forOp.getLoopBody().front().getOperations().end(),
136       whileOp.body().front().getOperations());
137 
138   b.setInsertionPointToStart(&forOp.getLoopBody().front());
139   auto loopIndVarElType =
140       loopIndVar.first.getType().cast<ShapedType>().getElementType();
141   Value indVar = b.create<SplatOp>(
142       whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
143       b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
144                             forOp.getInductionVar()));
145   // Update all block argument users to the SCF For args.
146   for (auto* use :
147        llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
148     // TODO(jpienaar): Expand here too when we allow using the tuple in the
149     // loop.
150     auto gte = cast<GetTupleElementOp>(use);
151     // If the loop induction var, then refer to the loop induction variable as
152     // this operand is not updated.
153     if (gte.index() == loopIndVar.second) {
154       use->getResult(0).replaceAllUsesWith(indVar);
155       use->erase();
156       continue;
157     }
158     int index = gte.index();
159     // If after the loop induction variable, then decrement as we don't include
160     // the loop induction variable in the for iter operands.
161     if (index > loopIndVar.second) --index;
162     use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
163     use->erase();
164   }
165 
166   // Create new yield op without induction var update.
167   SmallVector<Value, 4> newYieldOps;
168   newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
169   for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
170     newYieldOps.push_back(r);
171   for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
172     newYieldOps.push_back(r);
173   // Delete return & tuple op.
174   forOp.getLoopBody().front().back().erase();
175   forOp.getLoopBody().front().back().erase();
176   b.setInsertionPointToEnd(&forOp.getLoopBody().front());
177   b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
178 
179   // Recombine output tuple with max value of induction variable.
180   llvm::SmallVector<Value, 4> loopOut;
181   loopOut.reserve(forOp.getNumResults() + 1);
182   for (auto r : forOp.getResults().take_front(loopIndVar.second))
183     loopOut.push_back(r);
184   loopOut.push_back(max.first);
185   for (auto r : forOp.getResults().drop_front(loopIndVar.second))
186     loopOut.push_back(r);
187   b.setInsertionPoint(whileOp);
188   auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
189   whileOp.replaceAllUsesWith(newRes.getOperation());
190   whileOp.erase();
191 }
192 
193 }  // anonymous namespace
194 
createControlFlowToScfPass()195 std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
196   return std::make_unique<ControlFlowToScfPass>();
197 }
198 
199 }  // namespace mhlo
200 }  // namespace mlir
201