/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace lmhlo {
namespace {

using linalg::LinalgOp;

class LhloFuseLinalgPass
    : public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
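  // Tiling and fusion create affine, linalg, and scf ops, so those dialects
  // must be registered as dependencies of this pass.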
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
  }

 public:
  LhloFuseLinalgPass() = default;
  LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
  LhloFuseLinalgPass(bool use_parallel_loops,
                     llvm::ArrayRef<unsigned> tile_sizes) {
    tile_sizes_ = tile_sizes;
    use_parallel_loops_.setValue(use_parallel_loops);
  }

  void runOnFunction() override {
    auto func = getFunction();

    // TODO(pifon): Remove assumption that the function has a single block.
    if (!llvm::hasSingleElement(func)) {
      emitError(func.getLoc(), "The function needs to have a single block.");
      signalPassFailure();
      return;
    }

    // The fusion in Linalg is currently possible only when the consumer op is
    // tiled. In order to greedily fuse the ops, we have to start from the
    // tiled root linalg ops, i.e. linalg ops that write to output buffers of
    // the function or are returned in case of escaping allocations.
    llvm::SmallDenseSet<Value> result_buffers;
    for (auto func_arg : func.getArguments()) {
      result_buffers.insert(func_arg);
    }
    for (auto& block : func) {
      auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
      if (!returnOp) continue;
      for (auto operand : returnOp.getOperands()) {
        result_buffers.insert(operand);
      }
    }
    // Resolve aliasing operations (like casts) on the results to identify
    // the underlying buffers. This only handles escaping results.
    // TODO(herhut): Use BufferizeAliasAnalysis for this.
    llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
                                         result_buffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      auto definingOp = result.getDefiningOp();
      if (!definingOp) {
        continue;
      }

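      // View-like ops (casts, subviews, ...) alias their source buffer.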
      if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
        auto alias = viewLike.getViewSource();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

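      // A tensor_load result is backed by the memref it loads from.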
      if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
        auto alias = tensor_load.memref();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

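      // A tensor_to_memref result is backed by the tensor it materializes.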
      if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
        auto alias = tensor_to_memref.tensor();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

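      // A tensor.cast simply forwards its source tensor.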
      if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
        auto alias = tensor_cast.source();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

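      // Results of region-based control flow (e.g. scf.if) alias the values
      // yielded to the parent region by each return-like terminator.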
      if (auto regionInterface =
              dyn_cast<RegionBranchOpInterface>(definingOp)) {
        for (Region& region : regionInterface.getOperation()->getRegions()) {
          // Only consider regions that can return to the parent region.
          SmallVector<RegionSuccessor, 2> successorRegions;
          regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                              successorRegions);
          if (llvm::none_of(successorRegions, [&](auto successorRegion) {
                return successorRegion.isParent();
              }))
            continue;

          // Iterate over all immediate terminators and record the values
          // corresponding to result_buffers of interest.
          for (Block& block : region) {
            if (block.empty()) continue;
            Operation& operation = block.back();
            if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
            auto idx = result.dyn_cast<OpResult>().getResultNumber();
            if (result_buffers.insert(operation.getOperand(idx)).second) {
              worklist.push_back(operation.getOperand(idx));
            }
          }
        }
      }
    }

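    // Tile the root linalg.generic ops, i.e. those writing into one of the
    // result buffers collected above. Tiling replaces the original op, which
    // is then erased. Without explicit tile sizes, every loop is tiled by 1.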
    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
    func.walk([&](linalg::GenericOp generic_op) {
      SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
                                         tile_sizes_.end());
      if (tile_sizes.empty()) {
        tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
      }
      auto op = cast<LinalgOp>(generic_op.getOperation());
      for (const Value result : op.getOutputBuffers()) {
        if (!result_buffers.count(result)) continue;
        if (tileGenericOp(op, tile_sizes, &b)) {
          generic_op.erase();
          return;
        }
      }
    });
    auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

    // Fuse producers of tiled linalg ops.
    llvm::SmallDenseSet<Operation*> erase_set;
    SmallVector<LinalgOp, 8> linalg_ops;
    func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
    for (LinalgOp op : llvm::reverse(linalg_ops)) {
      for (OpOperand& inputOperand : op.getInputOpOperands()) {
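        // Rebuild the aliasing info and dependence graph for each operand:
        // a successful fusion mutates the IR, invalidating prior analyses.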
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
        if (auto info = fuseProducerOfBuffer(b, inputOperand, graph)) {
          auto originalOp = info->originalProducer.getOperation();
          erase_set.insert(originalOp);
          auto originalOpInLinalgOpsVector = std::find_if(
              linalg_ops.begin(), linalg_ops.end(),
              [&](const Operation* op) { return op == originalOp; });
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
        }
      }

      auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
    for (auto* e : erase_set) e->erase();
  }

 private:
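  // Tiles `op` with the given tile sizes, producing parallel loops when
  // `use-parallel-loops` is set and sequential loops otherwise. Returns true
  // if tiling succeeded.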
  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
    auto loopType = use_parallel_loops_
                        ? linalg::LinalgTilingLoopType::ParallelLoops
                        : linalg::LinalgTilingLoopType::Loops;
    auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
                                                 linalg::LinalgTilingOptions()
                                                     .setTileSizes(tile_sizes)
                                                     .setLoopType(loopType));
    return tiled_generic_op.hasValue();
  }

  Option<bool> use_parallel_loops_{
      *this, "use-parallel-loops",
      llvm::cl::desc(
          "Tiles GenericOp consumer to parallel loops before linalg fusion"),
      llvm::cl::init(false)};

  ListOption<unsigned> tile_sizes_{
      *this, "tile-sizes",
      llvm::cl::desc(
          "Tile sizes by which to tile linalg generic before linalg fusion"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};

}  // namespace

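// Creates an instance of the pass with the given options. Assuming the pass
// is registered under the name "lhlo-fuse-linalg" (the registration lives
// outside this file), a hypothetical command-line invocation could be:
//   mlir-opt -lhlo-fuse-linalg="use-parallel-loops tile-sizes=2,4" input.mlir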
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
    bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
  return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
}

}  // namespace lmhlo
}  // namespace mlir