/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU as well as for different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers (some of) the general CPU case and at the moment does not account for
// any target-specific configurations.

// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.

// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes into more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
    : public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
  void runOnFunction() override;
};

// Returns true if `op` is a supported activation function (Elu, Relu or
// Relu6).
bool IsActivationFunction(Operation *op) {
  return isa<EluOp, ReluOp, Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there
// are multiple such activations, one is returned (with no guarantee as to
// which one). If there are no activation functions that use the output,
// returns nullptr.
Operation *GetActivation(Value op) {
  for (auto &use : op.getUses()) {
    if (IsActivationFunction(use.getOwner())) return use.getOwner();
  }
  return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
  for (auto &use : op.getUses()) {
    auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
    // If it's a BiasAdd, check that `op` is its `value` (first) input.
    if (bias_add && bias_add.value() == op) return bias_add;
  }
  // No BiasAddOps found among the uses.
  return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
//   <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with an activation is preferred, but a contraction and
// BiasAdd alone can also be replaced by the fused contraction op if no
// activation follows, i.e. this class also supports the fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
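//
// As a rough sketch (eliding types, shapes and most attributes), the Conv2D
// instantiation of this pattern rewrites IR along the lines of
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//   %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"}
//   %2 = "tf.Relu"(%1)
//
// into something like
//
//   %2 = "tf._FusedConv2D"(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32, ...}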
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
 public:
  using OpRewritePattern<SrcOpT>::OpRewritePattern;
  // Class users should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
  virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                 PatternRewriter &rewriter) const {
    return true;
  }

  LogicalResult matchAndRewrite(SrcOpT contraction,
                                PatternRewriter &rewriter) const override {
    auto context = rewriter.getContext();
    // If the contraction is used in multiple places, fusing it will only
    // create more contraction nodes, which is slower.
    if (!contraction.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(contraction,
                                         "result is used by multiple ops");

    BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
    if (!bias_add) {
      return rewriter.notifyMatchFailure(
          contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
    }

    if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction, "cannot fuse with the subsequent BiasAdd op");
    }

    SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
        context, bias_add.getOperation()->getName().stripDialect())};

    // BiasAdd may or may not feed into an activation function.
    auto activation = GetActivation(bias_add);

    // If there is an activation, only fuse it if this is the only op to use
    // the result of the BiasAdd.
    bool fuse_activation = activation && bias_add.output().hasOneUse();
    Type result_type;

    // Include info about the activation function if applicable.
    if (fuse_activation) {
      locations.push_back(activation->getLoc());
      fused_ops.push_back(
          StringAttr::get(context, activation->getName().stripDialect()));
      result_type = activation->getResultTypes().front();
    } else {
      result_type = bias_add.getResult().getType();
    }

    auto fused_loc = rewriter.getFusedLoc(locations);

    // The fused contraction has the same operands as the original contraction,
    // with `bias` from the BiasAddOp appended.
    SmallVector<Value, 4> operands(contraction.operand_begin(),
                                   contraction.operand_end());
    operands.push_back(bias_add.bias());

    // The fused contraction has the same attributes as the original
    // contraction, with two additions: the list of ops that have been fused
    // together and epsilon (which is only used by FusedBatchNorm fusions).
    std::vector<NamedAttribute> attrs = contraction.getAttrs();
    ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
    attrs.push_back(
        NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
    // Epsilon is used only in fusions with the FusedBatchNorm op, so we zero
    // it here.
    Attribute epsilon = rewriter.getF32FloatAttr(0);
    attrs.push_back(
        NamedAttribute(Identifier::get("epsilon", context), epsilon));

    // Insert the fused operation right before the BiasAdd operation to
    // guarantee that the bias value dominates the fused operation. We already
    // verified that the original operation has a single use, so this is safe
    // to do.
    auto *bias_add_op = bias_add.getOperation();
    if (bias_add_op) rewriter.setInsertionPoint(bias_add_op);

    Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
                                               ValueRange(operands), attrs);
    auto op_to_replace = fuse_activation ? activation : bias_add;
    rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
    return success();
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
    : public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
 public:
  using FuseContractionWithBiasAdd<Conv2DOp,
                                   _FusedConv2DOp>::FuseContractionWithBiasAdd;
  // Verify that the Conv2D and BiasAdd data formats match. This is necessary
  // for the ops to fuse correctly: the fused Conv2D op has a single data
  // format attribute, which is shared by the original ops.
  bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // Verify that the data formats match and are valid for fusion.
    if (conv.data_format() != bias_add.data_format()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "BiasAdd data format does not match Conv2D data format ("
             << bias_add.data_format() << " vs " << conv.data_format() << ")";
      });
      return false;
    }
    // Verify that the data type is supported.
    if (!conv.T().isF32() && !conv.T().isF64()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedConv2D are float and double, "
             << "but got " << conv.T();
      });
      return false;
    }
    return true;
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
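//
// As a minimal sketch of the BiasAdd-only case (again eliding types and most
// attributes), assuming no activation follows the BiasAdd:
//
//   %0 = "tf.MatMul"(%lhs, %rhs)
//   %1 = "tf.BiasAdd"(%0, %bias)
//
// is rewritten into something like
//
//   %1 = "tf._FusedMatMul"(%lhs, %rhs, %bias) {fused_ops = ["BiasAdd"], ...}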
class FuseMatMulBiasAdd
    : public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
  using FuseContractionWithBiasAdd<MatMulOp,
                                   _FusedMatMulOp>::FuseContractionWithBiasAdd;

  bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // The FusedMatMul kernel supports a limited set of data types.
    if (!matmul.T().isF32() && !matmul.T().isBF16()) {
      (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedMatMul are float and bfloat16, "
             << "but got " << matmul.T();
      });
      return false;
    }
    return true;
  }
};

void FusedKernelMatcherPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
  patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());

  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
  return std::make_unique<FusedKernelMatcherPass>();
}

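// Registers the pass under the "tf-fused-kernel-matcher" flag so it can be
// referenced from textual pass pipelines (e.g. by an opt-style driver such as
// tf-opt; the exact driver used is an assumption here, not part of this file).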
static PassRegistration<FusedKernelMatcherPass> pass(
    "tf-fused-kernel-matcher",
    "Matches computations corresponding to optimized fused kernels");

}  // namespace TF

}  // namespace mlir