• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Utils.h"
14 
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/RegionUtils.h"
20 
21 #include "llvm/ADT/SetVector.h"
22 
23 using namespace mlir;
24 
cloneWithNewYields(OpBuilder & b,scf::ForOp loop,ValueRange newIterOperands,ValueRange newYieldedValues,bool replaceLoopResults)25 scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
26                                     ValueRange newIterOperands,
27                                     ValueRange newYieldedValues,
28                                     bool replaceLoopResults) {
29   assert(newIterOperands.size() == newYieldedValues.size() &&
30          "newIterOperands must be of the same size as newYieldedValues");
31 
32   // Create a new loop before the existing one, with the extra operands.
33   OpBuilder::InsertionGuard g(b);
34   b.setInsertionPoint(loop);
35   auto operands = llvm::to_vector<4>(loop.getIterOperands());
36   operands.append(newIterOperands.begin(), newIterOperands.end());
37   scf::ForOp newLoop =
38       b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
39                            loop.step(), operands);
40 
41   auto &loopBody = *loop.getBody();
42   auto &newLoopBody = *newLoop.getBody();
43   // Clone / erase the yield inside the original loop to both:
44   //   1. augment its operands with the newYieldedValues.
45   //   2. automatically apply the BlockAndValueMapping on its operand
46   auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
47   b.setInsertionPoint(yield);
48   auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
49   yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
50   auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);
51 
52   // Clone the loop body with remaps.
53   BlockAndValueMapping bvm;
54   // a. remap the induction variable.
55   bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
56   // b. remap the BB args.
57   bvm.map(loopBody.getArguments(),
58           newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
59   // c. remap the iter args.
60   bvm.map(newIterOperands,
61           newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
62   b.setInsertionPointToStart(&newLoopBody);
63   // Skip the original yield terminator which does not have enough operands.
64   for (auto &o : loopBody.without_terminator())
65     b.clone(o, bvm);
66 
67   // Replace `loop`'s results if requested.
68   if (replaceLoopResults) {
69     for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
70                                                     loop.getNumResults())))
71       std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
72   }
73 
74   // TODO: this is unsafe in the context of a PatternRewrite.
75   newYield.erase();
76 
77   return newLoop;
78 }
79 
outlineIfOp(OpBuilder & b,scf::IfOp ifOp,FuncOp * thenFn,StringRef thenFnName,FuncOp * elseFn,StringRef elseFnName)80 void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
81                        StringRef thenFnName, FuncOp *elseFn,
82                        StringRef elseFnName) {
83   Location loc = ifOp.getLoc();
84   MLIRContext *ctx = ifOp.getContext();
85   auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
86     assert(!funcName.empty() && "Expected function name for outlining");
87     assert(ifOrElseRegion.getBlocks().size() <= 1 &&
88            "Expected at most one block");
89 
90     // Outline before current function.
91     OpBuilder::InsertionGuard g(b);
92     b.setInsertionPoint(ifOp->getParentOfType<FuncOp>());
93 
94     llvm::SetVector<Value> captures;
95     getUsedValuesDefinedAbove(ifOrElseRegion, captures);
96 
97     ValueRange values(captures.getArrayRef());
98     FunctionType type =
99         FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
100     auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
101     b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
102     BlockAndValueMapping bvm;
103     for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
104       bvm.map(std::get<0>(it), std::get<1>(it));
105     for (Operation &op : ifOrElseRegion.front().without_terminator())
106       b.clone(op, bvm);
107 
108     Operation *term = ifOrElseRegion.front().getTerminator();
109     SmallVector<Value, 4> terminatorOperands;
110     for (auto op : term->getOperands())
111       terminatorOperands.push_back(bvm.lookup(op));
112     b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
113 
114     ifOrElseRegion.front().clear();
115     b.setInsertionPointToEnd(&ifOrElseRegion.front());
116     Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
117     b.create<scf::YieldOp>(loc, call->getResults());
118     return outlinedFunc;
119   };
120 
121   if (thenFn && !ifOp.thenRegion().empty())
122     *thenFn = outline(ifOp.thenRegion(), thenFnName);
123   if (elseFn && !ifOp.elseRegion().empty())
124     *elseFn = outline(ifOp.elseRegion(), elseFnName);
125 }
126