/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has GPU-specific cases as well as cases based on different
// target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers (some of) the general CPU case and at the moment does not
// account for any target-specific configurations.

// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.

// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes into more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
    : public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
  StringRef getArgument() const final { return "tf-fused-kernel-matcher"; }

  StringRef getDescription() const final {
    return "Matches computations corresponding to optimized fused kernels";
  }

  void runOnFunction() override;
};

bool IsActivationFunction(Operation *op) {
  return isa<EluOp, ReluOp, Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
  for (auto &use : op.getUses()) {
    if (IsActivationFunction(use.getOwner())) return use.getOwner();
  }
  return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
  for (auto &use : op.getUses()) {
    auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
    // If it's a BiasAdd, check that the contraction op is its `value` input.
    if (bias_add && bias_add.value() == op) return bias_add;
  }
  // No BiasAddOps found among uses.
  return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
//   <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with an activation is preferred, but a contraction and
// BiasAdd can also be replaced by a <FusedContraction> if there is no suitable
// activation function; i.e., this class also supports the following fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
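//
// As a rough illustration only (operand names and omitted attributes are
// hypothetical, not taken from a test in this file), the Conv2D instantiation
// rewrites IR shaped like
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//   %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"}
//   %2 = "tf.Relu"(%1)
// into a single op along the lines of
//   %2 = "tf._FusedConv2D"(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32, ...}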
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
 public:
  using OpRewritePattern<SrcOpT>::OpRewritePattern;
  // Subclasses should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
  virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
                                 PatternRewriter &rewriter) const {
    return true;
  }

  LogicalResult matchAndRewrite(SrcOpT contraction,
                                PatternRewriter &rewriter) const override {
    auto context = rewriter.getContext();

    // We support fusion only when the contraction operation is nested inside
    // one of the expected operations with regions. Other operations can have
    // semantics that are not compatible with fusion (e.g. region compilation).
    if (!isa<FuncOp, IfOp, WhileOp>(contraction->getParentOp())) {
      return rewriter.notifyMatchFailure(
          contraction,
          "fused operation must be nested inside a function, If or While");
    }

    // If the contraction is used in multiple places, fusing it will only
    // create more contraction nodes, which is slower.
    if (!contraction.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(contraction,
                                         "result is used by multiple ops");

    BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
    if (!bias_add) {
      return rewriter.notifyMatchFailure(
          contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
    }

    if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
      return rewriter.notifyMatchFailure(
          contraction, "cannot fuse with the subsequent BiasAdd op");
    }

    SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
        context, bias_add.getOperation()->getName().stripDialect())};

    // The BiasAdd may or may not feed into an activation function.
    auto activation = GetActivation(bias_add);

    // If there is an activation, only fuse it if the activation is the sole
    // user of the BiasAdd result.
    bool fuse_activation = activation && bias_add.output().hasOneUse();
    Type result_type;

    // Include info about the activation function if applicable.
    if (fuse_activation) {
      locations.push_back(activation->getLoc());
      fused_ops.push_back(
          StringAttr::get(context, activation->getName().stripDialect()));
      result_type = activation->getResultTypes().front();
    } else {
      result_type = bias_add.getResult().getType();
    }

    auto fused_loc = rewriter.getFusedLoc(locations);

    // The fused contraction has the same operands as the original contraction,
    // with `bias` from the BiasAddOp appended.
    SmallVector<Value, 4> operands(contraction.operand_begin(),
                                   contraction.operand_end());
    operands.push_back(bias_add.bias());

    // The fused contraction has the same attributes as the original
    // contraction, with two additions: the list of ops that have been fused
    // together, and epsilon (used only with FusedBatchNorm).
    std::vector<NamedAttribute> attrs = contraction->getAttrs();
    ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
    attrs.push_back(
        NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
    // Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
    // here.
    Attribute epsilon = rewriter.getF32FloatAttr(0);
    attrs.push_back(
        NamedAttribute(Identifier::get("epsilon", context), epsilon));

    // Insert the fused operation right before the BiasAdd operation to
    // guarantee that the bias value dominates the fused operation. We already
    // verified that the original operation has a single use, so this is safe
    // to do.
    auto *bias_add_op = bias_add.getOperation();
    if (bias_add_op) rewriter.setInsertionPoint(bias_add_op);

    Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
                                               ValueRange(operands), attrs);
    auto op_to_replace = fuse_activation ? activation : bias_add;
    rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
    return success();
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
    : public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
 public:
  using FuseContractionWithBiasAdd<Conv2DOp,
                                   _FusedConv2DOp>::FuseContractionWithBiasAdd;
  // Verify that the Conv2D and BiasAdd data formats match. This is necessary
  // for the ops to fuse correctly: the fused Conv2D op has a single data
  // format attribute, which is shared.
  bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // Verify that the data formats match and are valid for fusion.
    if (conv.data_format() != bias_add.data_format()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "BiasAdd data format does not match Conv2D data format ("
             << bias_add.data_format() << " vs " << conv.data_format() << ")";
      });
      return false;
    }
    // Verify the data type is supported.
    if (!conv.T().isF32() && !conv.T().isF64()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedConv2D are float and double, "
             << "but got " << conv.T();
      });
      return false;
    }
    return true;
  }
};

// Performs a fusion of the following pattern(s), if possible:
//   MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
class FuseMatMulBiasAdd
    : public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
  using FuseContractionWithBiasAdd<MatMulOp,
                                   _FusedMatMulOp>::FuseContractionWithBiasAdd;

  bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
                         PatternRewriter &rewriter) const override {
    // The FusedMatMul kernel supports a limited set of data types.
    if (!matmul.T().isF32() && !matmul.T().isBF16()) {
      (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
        diag << "supported data types for _FusedMatMul are float and bfloat16, "
             << "but got " << matmul.T();
      });
      return false;
    }
    return true;
  }
};

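// Adding another contraction fusion is, in a rough sketch, a matter of
// instantiating the template above with the source and fused op types, e.g.
//   class FuseFooBiasAdd
//       : public FuseContractionWithBiasAdd<FooOp, _FusedFooOp> { ... };
// and registering the pattern alongside the existing ones in runOnFunction()
// below. FooOp/_FusedFooOp are hypothetical placeholders, not real TF ops.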
void FusedKernelMatcherPass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto func = getFunction();
  patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());

  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
  return std::make_unique<FusedKernelMatcherPass>();
}

static PassRegistration<FusedKernelMatcherPass> pass;
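
// The registration above exposes the pass under the argument returned by
// getArgument(). Assuming the standard tf-opt driver links this pass in, it
// could be exercised with something like:
//   tf-opt -tf-fused-kernel-matcher input.mlir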

}  // namespace TF

}  // namespace mlir