1 //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SCF/Utils.h"
14
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/RegionUtils.h"
20
21 #include "llvm/ADT/SetVector.h"
22
23 using namespace mlir;
24
cloneWithNewYields(OpBuilder & b,scf::ForOp loop,ValueRange newIterOperands,ValueRange newYieldedValues,bool replaceLoopResults)25 scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
26 ValueRange newIterOperands,
27 ValueRange newYieldedValues,
28 bool replaceLoopResults) {
29 assert(newIterOperands.size() == newYieldedValues.size() &&
30 "newIterOperands must be of the same size as newYieldedValues");
31
32 // Create a new loop before the existing one, with the extra operands.
33 OpBuilder::InsertionGuard g(b);
34 b.setInsertionPoint(loop);
35 auto operands = llvm::to_vector<4>(loop.getIterOperands());
36 operands.append(newIterOperands.begin(), newIterOperands.end());
37 scf::ForOp newLoop =
38 b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
39 loop.step(), operands);
40
41 auto &loopBody = *loop.getBody();
42 auto &newLoopBody = *newLoop.getBody();
43 // Clone / erase the yield inside the original loop to both:
44 // 1. augment its operands with the newYieldedValues.
45 // 2. automatically apply the BlockAndValueMapping on its operand
46 auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
47 b.setInsertionPoint(yield);
48 auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
49 yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
50 auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);
51
52 // Clone the loop body with remaps.
53 BlockAndValueMapping bvm;
54 // a. remap the induction variable.
55 bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
56 // b. remap the BB args.
57 bvm.map(loopBody.getArguments(),
58 newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
59 // c. remap the iter args.
60 bvm.map(newIterOperands,
61 newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
62 b.setInsertionPointToStart(&newLoopBody);
63 // Skip the original yield terminator which does not have enough operands.
64 for (auto &o : loopBody.without_terminator())
65 b.clone(o, bvm);
66
67 // Replace `loop`'s results if requested.
68 if (replaceLoopResults) {
69 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
70 loop.getNumResults())))
71 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
72 }
73
74 // TODO: this is unsafe in the context of a PatternRewrite.
75 newYield.erase();
76
77 return newLoop;
78 }
79
outlineIfOp(OpBuilder & b,scf::IfOp ifOp,FuncOp * thenFn,StringRef thenFnName,FuncOp * elseFn,StringRef elseFnName)80 void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
81 StringRef thenFnName, FuncOp *elseFn,
82 StringRef elseFnName) {
83 Location loc = ifOp.getLoc();
84 MLIRContext *ctx = ifOp.getContext();
85 auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
86 assert(!funcName.empty() && "Expected function name for outlining");
87 assert(ifOrElseRegion.getBlocks().size() <= 1 &&
88 "Expected at most one block");
89
90 // Outline before current function.
91 OpBuilder::InsertionGuard g(b);
92 b.setInsertionPoint(ifOp->getParentOfType<FuncOp>());
93
94 llvm::SetVector<Value> captures;
95 getUsedValuesDefinedAbove(ifOrElseRegion, captures);
96
97 ValueRange values(captures.getArrayRef());
98 FunctionType type =
99 FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
100 auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
101 b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
102 BlockAndValueMapping bvm;
103 for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
104 bvm.map(std::get<0>(it), std::get<1>(it));
105 for (Operation &op : ifOrElseRegion.front().without_terminator())
106 b.clone(op, bvm);
107
108 Operation *term = ifOrElseRegion.front().getTerminator();
109 SmallVector<Value, 4> terminatorOperands;
110 for (auto op : term->getOperands())
111 terminatorOperands.push_back(bvm.lookup(op));
112 b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
113
114 ifOrElseRegion.front().clear();
115 b.setInsertionPointToEnd(&ifOrElseRegion.front());
116 Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
117 b.create<scf::YieldOp>(loc, call->getResults());
118 return outlinedFunc;
119 };
120
121 if (thenFn && !ifOp.thenRegion().empty())
122 *thenFn = outline(ifOp.thenRegion(), thenFnName);
123 if (elseFn && !ifOp.elseRegion().empty())
124 *elseFn = outline(ifOp.elseRegion(), elseFnName);
125 }
126