/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.
18
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
25 #include "mlir/Dialect/SCF/SCF.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Interfaces/ViewLikeInterface.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31
32 namespace mlir {
33 namespace lmhlo {
34 namespace {
35
36 using linalg::LinalgOp;
37
38 class LhloFuseLinalgPass
39 : public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const40 void getDependentDialects(DialectRegistry& registry) const override {
41 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
42 }
43
44 public:
45 LhloFuseLinalgPass() = default;
LhloFuseLinalgPass(const LhloFuseLinalgPass &)46 LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
LhloFuseLinalgPass(bool use_parallel_loops,llvm::ArrayRef<unsigned> tile_sizes)47 LhloFuseLinalgPass(bool use_parallel_loops,
48 llvm::ArrayRef<unsigned> tile_sizes) {
49 tile_sizes_ = tile_sizes;
50 use_parallel_loops_.setValue(use_parallel_loops);
51 }
52
runOnFunction()53 void runOnFunction() override {
54 auto func = getFunction();
55
56 // TODO(pifon): Remove assumption that the function has a single block.
57 if (!llvm::hasSingleElement(func)) {
58 emitError(func.getLoc(), "The function needs to have a single block.");
59 signalPassFailure();
60 return;
61 }
62
63 // The fusion in Linalg is currently possible only when the consumer op is
64 // tiled. In order to greedily fuse the ops, we have to start from the tiled
65 // root linalg ops, i.e. linalg ops that write to output buffers of the
66 // function or are returned in case of escaping allocations.
67 llvm::SmallDenseSet<Value> result_buffers;
68 for (auto func_arg : func.getArguments()) {
69 result_buffers.insert(func_arg);
70 }
71 for (auto& block : func) {
72 auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
73 if (!returnOp) continue;
74 for (auto operand : returnOp.getOperands()) {
75 result_buffers.insert(operand);
76 }
77 }
78 // Resolve aliasing operations (like casts) on the result to identify
79 // results. This only handles escaping results.
80 // TODO(herhut): Use BufferizeAliasAnalysis for this.
81 llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
82 result_buffers.end());
83 while (!worklist.empty()) {
84 Value result = worklist.pop_back_val();
85 auto definingOp = result.getDefiningOp();
86 if (!definingOp) {
87 continue;
88 }
89
90 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
91 auto alias = viewLike.getViewSource();
92 if (result_buffers.insert(alias).second) {
93 worklist.push_back(alias);
94 }
95 continue;
96 }
97
98 if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
99 auto alias = tensor_load.memref();
100 if (result_buffers.insert(alias).second) {
101 worklist.push_back(alias);
102 }
103 continue;
104 }
105
106 if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
107 auto alias = tensor_to_memref.tensor();
108 if (result_buffers.insert(alias).second) {
109 worklist.push_back(alias);
110 }
111 continue;
112 }
113
114 if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
115 auto alias = tensor_cast.source();
116 if (result_buffers.insert(alias).second) {
117 worklist.push_back(alias);
118 }
119 continue;
120 }
121
122 if (auto regionInterface =
123 dyn_cast<RegionBranchOpInterface>(definingOp)) {
124 for (Region& region : regionInterface.getOperation()->getRegions()) {
125 // Only consider regions that can return to the parent region.
126 SmallVector<RegionSuccessor, 2> successorRegions;
127 regionInterface.getSuccessorRegions(region.getRegionNumber(),
128 successorRegions);
129 if (llvm::none_of(successorRegions, [&](auto successorRegion) {
130 return successorRegion.isParent();
131 }))
132 continue;
133
134 // Iterate over all immediate terminators and record the values
135 // corresponding to result_buffers of interest.
136 for (Block& block : region) {
137 if (block.empty()) continue;
138 Operation& operation = block.back();
139 if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
140 auto idx = result.dyn_cast<OpResult>().getResultNumber();
141 if (result_buffers.insert(operation.getOperand(idx)).second) {
142 worklist.push_back(operation.getOperand(idx));
143 }
144 }
145 }
146 }
147 }
148
149 MLIRContext* ctx = func.getContext();
150 OpBuilder b(func);
151 func.walk([&](linalg::GenericOp generic_op) {
152 SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
153 tile_sizes_.end());
154 if (tile_sizes.empty()) {
155 tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
156 }
157 auto op = cast<LinalgOp>(generic_op.getOperation());
158 for (const Value result : op.getOutputBuffers()) {
159 if (!result_buffers.count(result)) continue;
160 if (tileGenericOp(op, tile_sizes, &b)) {
161 generic_op.erase();
162 return;
163 }
164 }
165 });
166 auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
167 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
168
169 // Fuse producers of tiled linalg ops.
170 llvm::SmallDenseSet<Operation*> erase_set;
171 SmallVector<LinalgOp, 8> linalg_ops;
172 func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
173 for (LinalgOp op : llvm::reverse(linalg_ops)) {
174 for (OpOperand& inputOperand : op.getInputOpOperands()) {
175 linalg::Aliases aliases;
176 linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
177 if (auto info = fuseProducerOfBuffer(b, inputOperand, graph)) {
178 auto originalOp = info->originalProducer.getOperation();
179 erase_set.insert(originalOp);
180 auto originalOpInLinalgOpsVector = std::find_if(
181 linalg_ops.begin(), linalg_ops.end(),
182 [&](const Operation* op) { return op == originalOp; });
183 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
184 }
185 }
186
187 auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
188 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
189 }
190 for (auto* e : erase_set) e->erase();
191 }
192
193 private:
tileGenericOp(LinalgOp op,ArrayRef<int64_t> tile_sizes,OpBuilder * b)194 bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
195 auto loopType = use_parallel_loops_
196 ? linalg::LinalgTilingLoopType::ParallelLoops
197 : linalg::LinalgTilingLoopType::Loops;
198 auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
199 linalg::LinalgTilingOptions()
200 .setTileSizes(tile_sizes)
201 .setLoopType(loopType));
202 return tiled_generic_op.hasValue();
203 }
204
205 Option<bool> use_parallel_loops_{
206 *this, "use-parallel-loops",
207 llvm::cl::desc(
208 "Tiles GenericOp consumer to parallel loops before linalg fusion"),
209 llvm::cl::init(false)};
210
211 ListOption<unsigned> tile_sizes_{
212 *this, "tile-sizes",
213 llvm::cl::desc(
214 "Tile sizes by which to tile linalg generic before linalg fusion"),
215 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
216 };
217
218 } // namespace
219
createLhloFuseLinalgPass(bool use_parallel_loops,ArrayRef<unsigned> tile_sizes)220 std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
221 bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
222 return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
223 }
224
225 } // namespace lmhlo
226 } // namespace mlir
227