• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Fuse tf.Op + tf.BiasAdd and legalized to TOSA
17 
18 #include <climits>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "mlir/Pass/Pass.h"  // from @llvm-project
26 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
30 
31 #define PASS_NAME "tosa-fuse-bias-tf"
32 #define DEBUG_TYPE PASS_NAME
33 
34 namespace mlir {
35 
36 namespace tosa {
37 
38 namespace {
39 
40 class FuseBiasTF : public PassWrapper<FuseBiasTF, FunctionPass> {
41  public:
FuseBiasTF()42   explicit FuseBiasTF() {}
43   void runOnFunction() override;
44 };
45 
46 struct ConvertTFBiasAddOp : public RewritePattern {
ConvertTFBiasAddOpmlir::tosa::__anondf80800b0111::ConvertTFBiasAddOp47   explicit ConvertTFBiasAddOp(MLIRContext* context)
48       : RewritePattern(TF::BiasAddOp::getOperationName(), 1, context) {}
49   LogicalResult matchAndRewrite(Operation* op,
50                                 PatternRewriter& rewriter) const override;
51 };
52 
53 // Replaces the following pattern:
54 //   %1 = tf.Conv2D (%ifm, %filter)
55 //   %2 = tf.BiasAdd(%1, %bias)
56 //   with
57 //   %1 = tosa.conv2d(%ifm, %filter, %bias)
58 //   This can also be done using the pair ot Pat<> options in
59 //   tf_optimize_patterns.td
60 //   However, this explicit code can handle both when the LHS or RHS is the
61 //   defining conv2d op.
62 // TODO: support other pattern. e.g. tf.DepthwiseConv2DNative
63 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const64 LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
65     Operation* op, PatternRewriter& rewriter) const {
66   auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
67 
68   auto output_type =
69       tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
70   // Not a ranked tensor output
71   if (!output_type) return failure();
72 
73   auto value = tf_biasadd_op.value();
74   auto bias = tf_biasadd_op.bias();
75 
76   TF::Conv2DOp tf_conv2d_op =
77       dyn_cast_or_null<TF::Conv2DOp>(value.getDefiningOp());
78 
79   if (!tf_conv2d_op) {
80     return failure();
81   }
82 
83   // Sanity check to confirm rhs() has the expected shape of bias
84   auto filter_shape =
85       tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>().getShape();
86 
87   auto bias_shape = bias.getType().dyn_cast<RankedTensorType>().getShape();
88 
89   // Bias dimension must match filter output channels, where tf.conv2d's filter
90   // is [H, W, I, O]
91   if (filter_shape.back() != bias_shape.back()) return failure();
92 
93   // Bias tensor that feeds into tosa.conv2d must be rank 1
94   if (bias_shape.size() != 1) return failure();
95 
96   auto result = convertTFConv2DCommon(
97       rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
98       bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
99       tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
100       tf_conv2d_op.data_format());
101 
102   if (!result) return failure();
103 
104   rewriter.replaceOp(op, {result.getValue()});
105 
106   return success();
107 }
108 
runOnFunction()109 void FuseBiasTF::runOnFunction() {
110   OwningRewritePatternList patterns;
111   auto* ctx = &getContext();
112   auto func = getFunction();
113 
114   // Add the generated patterns to the list.
115   patterns.insert<ConvertTFBiasAddOp>(ctx);
116   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
117 }
118 
119 }  // anonymous namespace
120 
createFuseBiasTFPass()121 std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass() {
122   return std::make_unique<FuseBiasTF>();
123 }
124 
125 static PassRegistration<FuseBiasTF> pass(
126     PASS_NAME, "Fuse tf.Op + tf.BiasAdd and legalized to TOSA.");
127 
128 }  // namespace tosa
129 
130 }  // namespace mlir
131