/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has GPU-specific cases as well as cases that depend on the
// target configuration on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers (some of) the general CPU case and at the moment does not
// account for any target-specific configurations.

// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.

// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes into more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
    : public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
  StringRef getArgument() const final { return "tf-fused-kernel-matcher"; }

  StringRef getDescription() const final {
    return "Matches computations corresponding to optimized fused kernels";
  }

  void runOnFunction() override;
};

bool IsActivationFunction(Operation *op) {
  return isa<EluOp, ReluOp, Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there
// are multiple such activations, one is returned (with no guarantee as to
// which one). If there are no activation functions that use the output,
// returns nullptr.
Operation *GetActivation(Value op) {
  for (auto &use : op.getUses()) {
    if (IsActivationFunction(use.getOwner())) return use.getOwner();
  }
  return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
  for (auto &use : op.getUses()) {
    auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
    // If it's a BiasAdd, check that the result of `op` is its `value` input
    // (rather than its `bias` input).
    if (bias_add && bias_add.value() == op) return bias_add;
  }
  // No BiasAddOps found among uses.
  return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
//   <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with an activation is preferred, but a contraction and
// BiasAdd alone can also be replaced by the fused contraction op if no
// activation function follows; i.e., this class also supports the fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
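//
// The resulting fused op takes the original contraction operands plus the
// bias from the BiasAdd, and carries a `fused_ops` string-array attribute
// naming the fused ops (e.g. ["BiasAdd", "Relu"]) as well as an `epsilon`
// attribute that is only meaningful for FusedBatchNorm fusions and is
// therefore set to zero here.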
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
 public:
  using OpRewritePattern<SrcOpT>::OpRewritePattern;
  // Subclasses should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
  virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                 PatternRewriter &rewriter) const {
    return true;
  }

  LogicalResult matchAndRewrite(SrcOpT contraction,
                                PatternRewriter &rewriter) const override {
    auto context = rewriter.getContext();

    // We support fusion only if the contraction operation is inside one of the
    // expected operations with regions. Other operations can have semantics
    // that are not compatible with fusion (e.g. region compilation).
    if (!isa<FuncOp, IfOp, WhileOp>(contraction->getParentOp())) {
      return rewriter.notifyMatchFailure(
          contraction,
          "fused operation must be nested inside a function, If or While");
    }

    // If the contraction is used in multiple places, fusing it will only
    // create more contraction nodes, which is slower.
    if (!contraction.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(contraction,
                                         "result is used by multiple ops");

    BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
    if (!bias_add) {
      return rewriter.notifyMatchFailure(
          contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
    }

    if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction, "cannot fuse with the subsequent BiasAdd op");
    }

    SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
        context, bias_add.getOperation()->getName().stripDialect())};

    // BiasAdd may or may not feed into an activation function.
    auto activation = GetActivation(bias_add);

    // If there is an activation, only fuse it if it is the only user of the
    // BiasAdd result.
    bool fuse_activation = activation && bias_add.output().hasOneUse();
    Type result_type;

    // Include info about the activation function if applicable.
    if (fuse_activation) {
      locations.push_back(activation->getLoc());
      fused_ops.push_back(
          StringAttr::get(context, activation->getName().stripDialect()));
      result_type = activation->getResultTypes().front();
    } else {
      result_type = bias_add.getResult().getType();
    }

    auto fused_loc = rewriter.getFusedLoc(locations);

    // The fused contraction has the same operands as the original contraction,
    // with the `bias` from the BiasAddOp appended.
    SmallVector<Value, 4> operands(contraction.operand_begin(),
                                   contraction.operand_end());
    operands.push_back(bias_add.bias());

    // The fused contraction has the same attributes as the original
    // contraction, with two additions: the list of ops that have been fused
    // together, and epsilon (used only with FusedBatchNorm).
    std::vector<NamedAttribute> attrs = contraction->getAttrs();
    ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
    attrs.push_back(
        NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
    // Epsilon is used only in fusions with the FusedBatchNorm op, so we zero
    // it here.
    Attribute epsilon = rewriter.getF32FloatAttr(0);
    attrs.push_back(
        NamedAttribute(Identifier::get("epsilon", context), epsilon));

    // Insert the fused operation right before the BiasAdd operation to
    // guarantee that the bias value dominates the fused operation. We already
    // verified that the original contraction has a single use, so this is safe
    // to do.
    auto *bias_add_op = bias_add.getOperation();
    if (bias_add_op) rewriter.setInsertionPoint(bias_add_op);

    Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
                                               ValueRange(operands), attrs);
    auto op_to_replace = fuse_activation ? activation : bias_add;
    rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
    return success();
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   Conv2D + BiasAdd + <Activation> -> _FusedConv2D
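//
// Illustrative example (IR sketch; types and most attributes elided):
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//   %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"}
//   %2 = "tf.Relu"(%1)
//
// is rewritten, roughly, to:
//
//   %2 = "tf._FusedConv2D"(%input, %filter, %bias)
//          {data_format = "NHWC", fused_ops = ["BiasAdd", "Relu"],
//           epsilon = 0.0 : f32, ...}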
class FuseConv2DBiasAdd
    : public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
 public:
  using FuseContractionWithBiasAdd<Conv2DOp,
                                   _FusedConv2DOp>::FuseContractionWithBiasAdd;
  // Verify that the Conv2D and BiasAdd data formats match. This is necessary
  // for the ops to fuse correctly because the fused Conv2D op has a single,
  // shared data format attribute.
  bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // Verify that the data formats match and are valid for fusion.
    if (conv.data_format() != bias_add.data_format()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "BiasAdd data format does not match Conv2D data format ("
             << bias_add.data_format() << " vs " << conv.data_format() << ")";
      });
      return false;
    }
    // Verify the data type is supported.
    if (!conv.T().isF32() && !conv.T().isF64()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedConv2D are float and double, "
             << "but got " << conv.T();
      });
      return false;
    }
    return true;
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   MatMul + BiasAdd + <Activation> -> _FusedMatMul
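//
// Illustrative example: tf.MatMul + tf.BiasAdd (+ tf.Relu) is rewritten to
// tf._FusedMatMul with fused_ops = ["BiasAdd"] (or ["BiasAdd", "Relu"]),
// provided the element type is f32 or bf16 (see AreFuseCompatible below).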
class FuseMatMulBiasAdd
    : public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
  using FuseContractionWithBiasAdd<MatMulOp,
                                   _FusedMatMulOp>::FuseContractionWithBiasAdd;

  bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // The FusedMatMul kernel supports a limited set of data types.
    if (!matmul.T().isF32() && !matmul.T().isBF16()) {
      (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedMatMul are float and bfloat16, "
             << "but got " << matmul.T();
      });
      return false;
    }
    return true;
  }
};

void FusedKernelMatcherPass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto func = getFunction();
  patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());

  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
  return std::make_unique<FusedKernelMatcherPass>();
}

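// The registration below makes the pass available under its argument,
// "tf-fused-kernel-matcher", e.g. from a pass pipeline or (assuming the
// standard tf-opt tool is built) on the command line:
//   tf-opt -tf-fused-kernel-matcher input.mlir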
static PassRegistration<FusedKernelMatcherPass> pass;

}  // namespace TF

}  // namespace mlir