1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Fuses tf.Op + tf.BiasAdd and legalizes the fused result to TOSA.
17
18 #include <climits>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23
24 #include "mlir/IR/MLIRContext.h" // from @llvm-project
25 #include "mlir/Pass/Pass.h" // from @llvm-project
26 #include "mlir/Support/LogicalResult.h" // from @llvm-project
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
30
31 #define PASS_NAME "tosa-fuse-bias-tf"
32 #define DEBUG_TYPE PASS_NAME
33
34 namespace mlir {
35
36 namespace tosa {
37
38 namespace {
39
40 class FuseBiasTF : public PassWrapper<FuseBiasTF, FunctionPass> {
41 public:
FuseBiasTF()42 explicit FuseBiasTF() {}
43 void runOnFunction() override;
44 };
45
46 struct ConvertTFBiasAddOp : public RewritePattern {
ConvertTFBiasAddOpmlir::tosa::__anondf80800b0111::ConvertTFBiasAddOp47 explicit ConvertTFBiasAddOp(MLIRContext* context)
48 : RewritePattern(TF::BiasAddOp::getOperationName(), 1, context) {}
49 LogicalResult matchAndRewrite(Operation* op,
50 PatternRewriter& rewriter) const override;
51 };
52
// Replaces the following pattern:
//   %1 = tf.Conv2D(%ifm, %filter)
//   %2 = tf.BiasAdd(%1, %bias)
// with
//   %1 = tosa.conv2d(%ifm, %filter, %bias)
// This could also be expressed as a pair of Pat<> rules in
// tf_optimize_patterns.td. Note that only the BiasAdd `value` operand is
// matched against a defining tf.Conv2D; a Conv2D feeding the `bias` operand
// is not fused.
// TODO: support other patterns, e.g. tf.DepthwiseConv2DNative.
63
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const64 LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
65 Operation* op, PatternRewriter& rewriter) const {
66 auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
67
68 auto output_type =
69 tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
70 // Not a ranked tensor output
71 if (!output_type) return failure();
72
73 auto value = tf_biasadd_op.value();
74 auto bias = tf_biasadd_op.bias();
75
76 TF::Conv2DOp tf_conv2d_op =
77 dyn_cast_or_null<TF::Conv2DOp>(value.getDefiningOp());
78
79 if (!tf_conv2d_op) {
80 return failure();
81 }
82
83 // Sanity check to confirm rhs() has the expected shape of bias
84 auto filter_shape =
85 tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>().getShape();
86
87 auto bias_shape = bias.getType().dyn_cast<RankedTensorType>().getShape();
88
89 // Bias dimension must match filter output channels, where tf.conv2d's filter
90 // is [H, W, I, O]
91 if (filter_shape.back() != bias_shape.back()) return failure();
92
93 // Bias tensor that feeds into tosa.conv2d must be rank 1
94 if (bias_shape.size() != 1) return failure();
95
96 auto result = convertTFConv2DCommon(
97 rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
98 bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
99 tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
100 tf_conv2d_op.data_format());
101
102 if (!result) return failure();
103
104 rewriter.replaceOp(op, {result.getValue()});
105
106 return success();
107 }
108
runOnFunction()109 void FuseBiasTF::runOnFunction() {
110 OwningRewritePatternList patterns;
111 auto* ctx = &getContext();
112 auto func = getFunction();
113
114 // Add the generated patterns to the list.
115 patterns.insert<ConvertTFBiasAddOp>(ctx);
116 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
117 }
118
119 } // anonymous namespace
120
createFuseBiasTFPass()121 std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass() {
122 return std::make_unique<FuseBiasTF>();
123 }
124
125 static PassRegistration<FuseBiasTF> pass(
126 PASS_NAME, "Fuse tf.Op + tf.BiasAdd and legalized to TOSA.");
127
128 } // namespace tosa
129
130 } // namespace mlir
131