1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for fusing linalg ops obtained after LHLO
17 // lowering.
18
19 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
20 #include "absl/memory/memory.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h" // TF:llvm-project
22 #include "mlir/EDSC/Helpers.h" // TF:llvm-project
23 #include "mlir/Pass/Pass.h" // TF:llvm-project
24
25 namespace mlir {
26 namespace xla_lhlo {
27 namespace {
28
29 using linalg::LinalgOp;
30
31 struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
runOnFunctionmlir::xla_lhlo::__anoncee307aa0111::LhloFuseLinalg32 void runOnFunction() override {
33 auto func = getFunction();
34
35 // TODO(pifon): Remove assumption that the function has a single block.
36 if (func.getBlocks().size() != 1) {
37 emitError(func.getLoc(), "The function needs to have a single block.");
38 signalPassFailure();
39 return;
40 }
41
42 // The fusion in Linalg is currently possible only when the consumer op is
43 // tiled. In order to greedily fuse the ops, we have to start from the tiled
44 // root linalg ops, i.e. linalg ops that write to output buffers of the
45 // function.
46 llvm::SmallDenseSet<Value> func_args;
47 for (auto func_arg : func.getArguments()) {
48 func_args.insert(func_arg);
49 }
50 OpBuilder b(func);
51 OperationFolder folder(func.getContext());
52 func.walk([&](linalg::GenericOp generic_op) {
53 const SmallVector<int64_t, 2> tile_sizes(
54 generic_op.getNumInputsAndOutputs(), 1);
55 auto op = cast<LinalgOp>(generic_op.getOperation());
56 for (const Value result : op.getOutputBuffers()) {
57 if (!func_args.count(result)) continue;
58 if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
59 &folder)) {
60 generic_op.erase();
61 return;
62 }
63 }
64 });
65
66 // Fuse producers of tiled linalg ops.
67 llvm::SmallDenseSet<Operation*> erase_set;
68 SmallVector<Operation*, 8> linalg_ops;
69 func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
70 for (auto* op : llvm::reverse(linalg_ops)) {
71 for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
72 linalg::Aliases aliases;
73 linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
74 if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
75 auto originalOp = info->originalProducer.getOperation();
76 erase_set.insert(originalOp);
77 auto originalOpInLinalgOpsVector = std::find_if(
78 linalg_ops.begin(), linalg_ops.end(),
79 [&](const Operation* op) { return op == originalOp; });
80 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
81 }
82 }
83 }
84 for (auto* e : erase_set) e->erase();
85 }
86 };
87
88 } // namespace
89
createLhloFuseLinalg()90 std::unique_ptr<OpPassBase<FuncOp>> createLhloFuseLinalg() {
91 return absl::make_unique<LhloFuseLinalg>();
92 }
93
94 static PassRegistration<LhloFuseLinalg> legalize_pass(
95 "lhlo-fuse-linalg",
96 "Greedily fuse linalg ops obtained after LHLO lowering.");
97
98 } // namespace xla_lhlo
99 } // namespace mlir
100