/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.
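//
// An illustrative sketch (simplified and hand-written, not taken from a real
// test) of the transformation on a function whose root linalg op writes to a
// function argument:
//
//   func @fusion(%in: memref<16xf32>, %out: memref<16xf32>) {
//     %tmp = memref.alloc() : memref<16xf32>
//     linalg.generic ... ins(%in : memref<16xf32>)
//                        outs(%tmp : memref<16xf32>) { ... }  // producer
//     linalg.generic ... ins(%tmp : memref<16xf32>)
//                        outs(%out : memref<16xf32>) { ... }  // root
//     return
//   }
//
// The root op (the one writing %out) is tiled first, and the producer that
// defines %tmp is then fused into the resulting tile loops, so each tile of
// %tmp is computed right before the tile of the root that consumes it.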

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace lmhlo {
namespace {

using linalg::LinalgOp;

class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
  }

 public:
  LhloFuseLinalgPass() = default;
  LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
  LhloFuseLinalgPass(bool use_parallel_loops,
                     llvm::ArrayRef<unsigned> tile_sizes) {
    tile_sizes_ = tile_sizes;
    use_parallel_loops_.setValue(use_parallel_loops);
  }

  void runOnFunction() override {
    auto func = getFunction();

    // TODO(pifon): Remove assumption that the function has a single block.
    if (!llvm::hasSingleElement(func)) {
      emitError(func.getLoc(), "The function needs to have a single block.");
      signalPassFailure();
      return;
    }

    // The fusion in Linalg is currently possible only when the consumer op is
    // tiled. In order to greedily fuse the ops, we have to start from the
    // tiled root linalg ops, i.e. linalg ops that write to output buffers of
    // the function or are returned in case of escaping allocations.
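    // Concretely, the root buffer set is seeded below with the function
    // arguments and the operands of the return, and then extended with their
    // aliases.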
    llvm::SmallDenseSet<Value> result_buffers;
    for (auto func_arg : func.getArguments()) {
      result_buffers.insert(func_arg);
    }
    for (auto& block : func) {
      auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
      if (!returnOp) continue;
      for (auto operand : returnOp.getOperands()) {
        result_buffers.insert(operand);
      }
    }
    // Resolve aliasing operations (like casts) on the results to identify the
    // underlying result buffers. This only handles escaping results.
    // TODO(herhut): Use BufferizeAliasAnalysis for this.
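    // The worklist below follows view-like ops, memref.tensor_load,
    // memref.buffer_cast, tensor.cast and return-like region terminators back
    // to the buffers they alias.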
    llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
                                         result_buffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      auto definingOp = result.getDefiningOp();
      if (!definingOp) {
        continue;
      }

      if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
        auto alias = viewLike.getViewSource();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
        auto alias = tensor_load.memref();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
        auto alias = tensor_to_memref.tensor();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
        auto alias = tensor_cast.source();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto regionInterface =
              dyn_cast<RegionBranchOpInterface>(definingOp)) {
        for (Region& region : regionInterface.getOperation()->getRegions()) {
          // Only consider regions that can return to the parent region.
          SmallVector<RegionSuccessor, 2> successorRegions;
          regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                              successorRegions);
          if (llvm::none_of(successorRegions, [&](auto successorRegion) {
                return successorRegion.isParent();
              }))
            continue;

          // Iterate over all immediate terminators and record the values
          // corresponding to result_buffers of interest.
          for (Block& block : region) {
            if (block.empty()) continue;
            Operation& operation = block.back();
            if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
            auto idx = result.dyn_cast<OpResult>().getResultNumber();
            if (result_buffers.insert(operation.getOperand(idx)).second) {
              worklist.push_back(operation.getOperand(idx));
            }
          }
        }
      }
    }

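    // Tile every linalg.generic op that writes into one of the result buffers
    // identified above. When no tile sizes are provided, every loop is tiled
    // with size 1.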
    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
    func.walk([&](linalg::GenericOp generic_op) {
      SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
                                         tile_sizes_.end());
      if (tile_sizes.empty()) {
        tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
      }
      auto op = cast<LinalgOp>(generic_op.getOperation());
      for (OpOperand* op_operand : op.getOutputBufferOperands()) {
        if (!result_buffers.count(op_operand->get())) continue;
        if (tileGenericOp(op, tile_sizes, &b)) {
          generic_op.erase();
          return;
        }
      }
    });
    auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

    // Fuse producers of tiled linalg ops.
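    // The linalg ops are walked in reverse order, and a fresh aliasing
    // analysis and dependence graph are built for every candidate operand,
    // since each successful fusion mutates the IR. Fused-away producers are
    // only erased once the whole traversal is done.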
    llvm::SmallDenseSet<Operation*> erase_set;
    SmallVector<LinalgOp, 8> linalg_ops;
    func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
    for (LinalgOp op : llvm::reverse(linalg_ops)) {
      for (OpOperand* inputOperand : op.getInputOperands()) {
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
        if (auto info = fuseProducerOfBuffer(b, *inputOperand, graph)) {
          auto originalOp = info->originalProducer.getOperation();
          erase_set.insert(originalOp);
          auto originalOpInLinalgOpsVector = std::find_if(
              linalg_ops.begin(), linalg_ops.end(),
              [&](const Operation* op) { return op == originalOp; });
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
        }
      }

      auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
    for (auto* e : erase_set) e->erase();
  }

 private:
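  // Tiles `op` with the given tile sizes, emitting parallel loops when the
  // pass option requests them and sequential loops otherwise. Returns true if
  // tiling succeeded, so the caller can erase the original op.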
  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
    auto loopType = use_parallel_loops_
                        ? linalg::LinalgTilingLoopType::ParallelLoops
                        : linalg::LinalgTilingLoopType::Loops;
    auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
                                                 linalg::LinalgTilingOptions()
                                                     .setTileSizes(tile_sizes)
                                                     .setLoopType(loopType));
    return tiled_generic_op.hasValue();
  }
};

}  // namespace

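// Illustrative usage only (not part of this file): the pass is typically
// added to a function-level pass pipeline, e.g.
//
//   pm.addNestedPass<FuncOp>(lmhlo::createLhloFuseLinalgPass(
//       /*use_parallel_loops=*/true, /*tile_sizes=*/{}));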
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
    bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
  return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
}

}  // namespace lmhlo
}  // namespace mlir