/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/MLIRContext.h" // from @llvm-project
20 #include "mlir/IR/UseDefLists.h" // from @llvm-project
21 #include "mlir/Pass/Pass.h" // from @llvm-project
22 #include "mlir/Support/LogicalResult.h" // from @llvm-project
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25
26 namespace mlir {
27 namespace TF {
28 namespace {
29
30 // -------------------------------------------------------------------------- //
31 // Fuse ContractionFusableInterface operations into contraction operation.
32 // -------------------------------------------------------------------------- //
33
34 template <typename BaseOp, typename FusedOp>
35 class FuseIntoContractionOp : public RewritePattern {
36 public:
FuseIntoContractionOp()37 FuseIntoContractionOp()
38 : RewritePattern(PatternBenefit(1), MatchAnyOpTypeTag()) {}
39
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const40 LogicalResult matchAndRewrite(Operation *op,
41 PatternRewriter &rewriter) const override {
42 auto fusable = dyn_cast<ContractionFusableInterface>(op);
43 if (!fusable) return failure();
44
45 auto failed = [&](Twine message) -> LogicalResult {
46 return rewriter.notifyMatchFailure(op, message);
47 };
48
49 // Check if the operation can be fused.
50 Optional<ContractionFusion> fusion = fusable.GetContractionFusion();
51 if (!fusion.hasValue()) {
52 return failed("returned empty contraction fusion specification");
53 }
54
55 // Check if preceeding operation is a BaseOp or FusedOp that we can use for
56 // fusion.
57 Operation *fuse_into = nullptr;
58 Value operand = op->getOperand(0);
59
60 if (BaseOp base_op = operand.getDefiningOp<BaseOp>()) {
61 fuse_into = base_op.getOperation();
62 } else if (FusedOp fused_op = operand.getDefiningOp<FusedOp>()) {
63 fuse_into = fused_op.getOperation();
64 } else {
65 return failed("input to the fusable op must be a " +
66 BaseOp::getOperationName() + " or a " +
67 FusedOp::getOperationName());
68 }
69
70 // Operand result must have one use, because we do not want to compute
71 // tensor contraction twice.
72 if (!fuse_into->getResult(0).hasOneUse()) {
73 return failed("fused into op result must have one use");
74 }
75
76 MLIRContext *ctx = op->getContext();
77
78 // Build a fused MatMul operation from a base MatMul and a fusion.
79 SmallVector<Location, 3> locations = {fuse_into->getLoc(), op->getLoc()};
80 Location loc = rewriter.getFusedLoc(locations);
81
82 // Fusion can't change the type of a fused operation.
83 Type result_ty = fuse_into->getResult(0).getType();
84
85 // Copy all operands from a base op and add additional fusion arguments.
86 SmallVector<Value, 3> operands(fuse_into->getOperands());
87 for (int idx : fusion->additional_arguments) {
88 operands.push_back(op->getOperand(idx));
89 }
90
91 // Copy attributes from a base op that we fuse into (e.g. copy all
92 // MatMul or Conv attributes to the fused operation).
93 SmallVector<NamedAttribute, 4> attrs(fuse_into->getAttrs().begin(),
94 fuse_into->getAttrs().end());
95
96 // Add fusion specific additional attributes.
97 for (auto attr : fusion->additional_attributes) {
98 attrs.push_back(attr);
99 }
100
101 // Add a fused output kernel name to the list of fusions.
102 Identifier fusion_id = Identifier::get("fusion", ctx);
103 StringAttr fusion_name = StringAttr::get(ctx, fusion->output_kernel);
104
105 auto is_fusion = [&](const NamedAttribute &attr) -> bool {
106 return attr.first == fusion_id;
107 };
108
109 if (isa<BaseOp>(fuse_into)) {
110 NamedAttribute fusion_attr(fusion_id, ArrayAttr::get(ctx, {fusion_name}));
111 attrs.push_back(fusion_attr);
112
113 } else {
114 ArrayAttr arr =
115 llvm::find_if(attrs, is_fusion)->second.template cast<ArrayAttr>();
116 llvm::erase_if(attrs, is_fusion);
117
118 auto rng = arr.getAsRange<Attribute>();
119 SmallVector<Attribute, 4> updated(rng.begin(), rng.end());
120 updated.push_back(fusion_name);
121
122 attrs.push_back(NamedAttribute(fusion_id, ArrayAttr::get(ctx, updated)));
123 }
124
125 // Update all uses of a fusable op with a new fused operation.
126 Value fused = rewriter.create<FusedOp>(loc, result_ty, operands, attrs);
127 rewriter.replaceOp(op, {fused});
128
129 return failure();
130 }
131 };
132
133 // -------------------------------------------------------------------------- //
134
135 using FuseIntoMatMulOp = FuseIntoContractionOp<MatMulOp, _JitFusedMatMulOp>;
136
// Function pass that greedily applies contraction fusion patterns to every
// operation in the function body.
struct ContractionFusionPass
    : public PassWrapper<ContractionFusionPass, FunctionPass> {
  void runOnFunction() override;
};
141
runOnFunction()142 void ContractionFusionPass::runOnFunction() {
143 FuncOp func = getFunction();
144
145 OwningRewritePatternList patterns;
146 patterns.insert<FuseIntoMatMulOp>();
147 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
148 }
149
150 } // namespace
151
CreateContractionFusionPass()152 std::unique_ptr<OperationPass<FuncOp>> CreateContractionFusionPass() {
153 return std::make_unique<ContractionFusionPass>();
154 }
155
// Registers the pass so it is available as -tf-contraction-fusion in
// pass-pipeline command line options.
static PassRegistration<ContractionFusionPass> pass(
    "tf-contraction-fusion",
    "Fuses operations implementing ContractionFusionInterface into the "
    "contraction operations");
160
161 } // namespace TF
162 } // namespace mlir
163