/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU and for different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers (some of) the general CPU case and at the moment does not account
// for any target-specific configurations.

// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.

// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes into more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
    : public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
  void runOnFunction() override;
};

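// Returns true if `op` is an activation function supported by the fusions in
// this pass (currently Elu, Relu, and Relu6).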
bool IsActivationFunction(Operation *op) {
  return isa<EluOp, ReluOp, Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
  for (auto &use : op.getUses()) {
    if (IsActivationFunction(use.getOwner())) return use.getOwner();
  }
  return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
  for (auto &use : op.getUses()) {
    auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
    // If it's a BiasAdd, check that the conv op is the first input.
    if (bias_add && bias_add.value() == op) return bias_add;
  }
  // No BiasAddOps found among uses.
  return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
//   <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with an activation is preferred, but a contraction and
// BiasAdd on their own can also be replaced by the fused contraction op if no
// activation function can be fused; i.e., this class also supports the
// following fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
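//
// Illustrative sketch of the Conv2D case (operands and most attributes are
// elided; this is not verbatim IR from any particular test):
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//   %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"}
//   %2 = "tf.Relu"(%1)
//
// becomes, provided %0 and %1 have no other uses:
//
//   %2 = "tf._FusedConv2D"(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32, ...}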
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
 public:
  using OpRewritePattern<SrcOpT>::OpRewritePattern;
  // Class users should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
  virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                 PatternRewriter &rewriter) const {
    return true;
  }

  LogicalResult matchAndRewrite(SrcOpT contraction,
                                PatternRewriter &rewriter) const override {
    auto context = rewriter.getContext();
    // If the contraction is used in multiple places, fusing it will only create
    // more contraction nodes, which is slower.
    if (!contraction.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(contraction,
                                         "result is used by multiple ops");

    BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
    if (!bias_add) {
      return rewriter.notifyMatchFailure(
          contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
    }

    if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction, "cannot fuse with the subsequent BiasAdd op");
    }

    SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
        context, bias_add.getOperation()->getName().stripDialect())};

    // BiasAdd may or may not feed into an activation function.
    auto activation = GetActivation(bias_add);

    // If there is an activation, only fuse it if this is the only op to use the
    // result of the BiasAdd.
    bool fuse_activation = activation && bias_add.output().hasOneUse();
    Type result_type;

    // Include info about the activation function if applicable.
    if (fuse_activation) {
      locations.push_back(activation->getLoc());
      fused_ops.push_back(
          StringAttr::get(context, activation->getName().stripDialect()));
      result_type = activation->getResultTypes().front();
    } else {
      result_type = bias_add.getResult().getType();
    }

    auto fused_loc = rewriter.getFusedLoc(locations);

    // The fused contraction has the same operands as the original contraction
    // with `bias` from the BiasAddOp appended.
    SmallVector<Value, 4> operands(contraction.operand_begin(),
                                   contraction.operand_end());
    operands.push_back(bias_add.bias());

    // The fused contraction has the same attributes as the original
    // contraction, with two additions: the list of ops that have been fused
    // together, and epsilon (which is only used by FusedBatchNorm fusions).
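    // For example (illustrative), fusing Conv2D + BiasAdd + Relu produces
    // fused_ops = ["BiasAdd", "Relu"] and epsilon = 0.0 : f32.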
    std::vector<NamedAttribute> attrs = contraction.getAttrs();
    ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
    attrs.push_back(
        NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
    // Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
    // here.
    Attribute epsilon = rewriter.getF32FloatAttr(0);
    attrs.push_back(
        NamedAttribute(Identifier::get("epsilon", context), epsilon));

    // Insert the fused operation right before the BiasAdd operation to
    // guarantee that the bias value dominates the fused operation. We already
    // verified that the original operation has a single use, so this is safe.
    auto *bias_add_op = bias_add.getOperation();
    if (bias_add_op) rewriter.setInsertionPoint(bias_add_op);

    Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
                                               ValueRange(operands), attrs);
    auto op_to_replace = fuse_activation ? activation : bias_add;
    rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
    return success();
  }
};
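
// The pattern above is reusable: a new contraction fusion only needs to pick
// the source/fused op pair, override AreFuseCompatible if necessary, and be
// registered in runOnFunction alongside the existing patterns. The sketch
// below is illustrative only; FooOp and _FusedFooOp are hypothetical names,
// not ops defined in the TF dialect.
//
//   class FuseFooBiasAdd
//       : public FuseContractionWithBiasAdd<FooOp, _FusedFooOp> {
//    public:
//     using FuseContractionWithBiasAdd<FooOp, _FusedFooOp>::
//         FuseContractionWithBiasAdd;
//     bool AreFuseCompatible(FooOp foo, BiasAddOp bias_add,
//                            PatternRewriter &rewriter) const override {
//       // Check any FooOp-specific operand/attribute constraints here.
//       return true;
//     }
//   };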

// Performs a fusion of the following pattern(s), if possible:
//   Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
    : public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
 public:
  using FuseContractionWithBiasAdd<Conv2DOp,
                                   _FusedConv2DOp>::FuseContractionWithBiasAdd;
  // Verify that the Conv2D and BiasAdd data formats match. This is necessary
  // for the ops to fuse correctly: the fused Conv2D op has a single, shared
  // data format attribute.
  bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // Verify that the data formats match and are valid for fusion.
    if (conv.data_format() != bias_add.data_format()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "data format does not match Conv2D data format ("
             << bias_add.data_format() << " vs " << conv.data_format() << ")";
      });
      return false;
    }
    // Verify the data type is supported.
    if (!conv.T().isF32() && !conv.T().isF64()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedConv2D are float and double, "
             << "but got " << conv.T();
      });
      return false;
    }
    return true;
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
class FuseMatMulBiasAdd
    : public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
  using FuseContractionWithBiasAdd<MatMulOp,
                                   _FusedMatMulOp>::FuseContractionWithBiasAdd;

  bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // The FusedMatMul kernel supports a limited set of data types.
    if (!matmul.T().isF32() && !matmul.T().isBF16()) {
      (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedMatMul are float and bfloat16, "
             << "but got " << matmul.T();
      });
      return false;
    }
    return true;
  }
};

void FusedKernelMatcherPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
  patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());

  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
  return std::make_unique<FusedKernelMatcherPass>();
}

static PassRegistration<FusedKernelMatcherPass> pass(
    "tf-fused-kernel-matcher",
    "Matches computations corresponding to optimized fused kernels");

}  // namespace TF

}  // namespace mlir