/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace lmhlo {
namespace {

using linalg::LinalgOp;

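// Tiles linalg ops that write to the function's result buffers and then
// greedily fuses their producers into the resulting loop nests.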
class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
  }

 public:
  LhloFuseLinalgPass() = default;
  LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
  LhloFuseLinalgPass(bool use_parallel_loops,
                     llvm::ArrayRef<unsigned> tile_sizes) {
    tile_sizes_ = tile_sizes;
    use_parallel_loops_.setValue(use_parallel_loops);
  }

  void runOnFunction() override {
    auto func = getFunction();

    // TODO(pifon): Remove assumption that the function has a single block.
    if (!llvm::hasSingleElement(func)) {
      emitError(func.getLoc(), "The function needs to have a single block.");
      signalPassFailure();
      return;
    }

    // The fusion in Linalg is currently possible only when the consumer op is
    // tiled. In order to greedily fuse the ops, we have to start from the tiled
    // root linalg ops, i.e. linalg ops that write to output buffers of the
    // function or are returned in case of escaping allocations.
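    //
    // A schematic (hypothetical) example: given
    //   %tmp = memref.alloc(...)
    //   linalg.generic ins(%arg0) outs(%tmp)   // producer
    //   linalg.generic ins(%tmp) outs(%arg1)   // root: writes a function
    //                                          // argument buffer
    // the root is tiled into a loop nest first, and the producer is then
    // fused into the body of that nest.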
    llvm::SmallDenseSet<Value> result_buffers;
    for (auto func_arg : func.getArguments()) {
      result_buffers.insert(func_arg);
    }
    for (auto& block : func) {
      auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
      if (!returnOp) continue;
      for (auto operand : returnOp.getOperands()) {
        result_buffers.insert(operand);
      }
    }
    // Resolve aliasing operations (like casts) on the results to identify the
    // underlying buffers. This only handles escaping results.
    // TODO(herhut): Use BufferizeAliasAnalysis for this.
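    // The worklist walk below follows view-like ops, memref.tensor_load,
    // memref.buffer_cast, tensor.cast, and region-based control flow
    // backwards, adding every alias it finds to `result_buffers`.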
    llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
                                         result_buffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      auto definingOp = result.getDefiningOp();
      if (!definingOp) {
        continue;
      }

      if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
        auto alias = viewLike.getViewSource();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
        auto alias = tensor_load.memref();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
        auto alias = tensor_to_memref.tensor();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
        auto alias = tensor_cast.source();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto regionInterface =
              dyn_cast<RegionBranchOpInterface>(definingOp)) {
        for (Region& region : regionInterface.getOperation()->getRegions()) {
          // Only consider regions that can return to the parent region.
          SmallVector<RegionSuccessor, 2> successorRegions;
          regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                              successorRegions);
          if (llvm::none_of(successorRegions, [&](auto successorRegion) {
                return successorRegion.isParent();
              }))
            continue;

          // Iterate over all immediate terminators and record the values
          // corresponding to result_buffers of interest.
          for (Block& block : region) {
            if (block.empty()) continue;
            Operation& operation = block.back();
            if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
            auto idx = result.dyn_cast<OpResult>().getResultNumber();
            if (result_buffers.insert(operation.getOperand(idx)).second) {
              worklist.push_back(operation.getOperand(idx));
            }
          }
        }
      }
    }

    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
    func.walk([&](linalg::GenericOp generic_op) {
      SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
                                         tile_sizes_.end());
      if (tile_sizes.empty()) {
        tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
      }
      auto op = cast<LinalgOp>(generic_op.getOperation());
      for (OpOperand* op_operand : op.getOutputBufferOperands()) {
        if (!result_buffers.count(op_operand->get())) continue;
        if (tileGenericOp(op, tile_sizes, &b)) {
          generic_op.erase();
          return;
        }
      }
    });
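    // Clean up the IR produced by tiling before searching for fusion
    // candidates.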
    auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

    // Fuse producers of tiled linalg ops.
    llvm::SmallDenseSet<Operation*> erase_set;
    SmallVector<LinalgOp, 8> linalg_ops;
    func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
    for (LinalgOp op : llvm::reverse(linalg_ops)) {
      for (OpOperand* inputOperand : op.getInputOperands()) {
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
        if (auto info = fuseProducerOfBuffer(b, *inputOperand, graph)) {
          auto originalOp = info->originalProducer.getOperation();
          erase_set.insert(originalOp);
          auto originalOpInLinalgOpsVector = std::find_if(
              linalg_ops.begin(), linalg_ops.end(),
              [&](const Operation* op) { return op == originalOp; });
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
        }
      }

      auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
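    // Erase the original producers only now, once the fusion loop no longer
    // needs to inspect them.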
    for (auto* e : erase_set) e->erase();
  }

 private:
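  // Tiles `op` with the given tile sizes, using parallel loops if requested
  // via the pass option. Returns true if tiling succeeded.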
  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
    auto loopType = use_parallel_loops_
                        ? linalg::LinalgTilingLoopType::ParallelLoops
                        : linalg::LinalgTilingLoopType::Loops;
    auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
                                                 linalg::LinalgTilingOptions()
                                                     .setTileSizes(tile_sizes)
                                                     .setLoopType(loopType));
    return tiled_generic_op.hasValue();
  }
};

}  // namespace

std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
    bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
  return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
}
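
// A minimal usage sketch (assuming a standard mlir::PassManager setup; not
// part of this file's API):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::lmhlo::createLhloFuseLinalgPass(/*use_parallel_loops=*/true,
//                                             /*tile_sizes=*/{2, 4}));
//   (void)pm.run(module);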

}  // namespace lmhlo
}  // namespace mlir