• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_
18 
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
21 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
24 
25 namespace mlir {
26 namespace TFL {
27 namespace tac {
28 
29 // TODO(renjieliu): add more patterns.
30 
31 // This basically:
32 // Pack => (Concat -> Reshape)
33 struct LowerPackIntoConcatReshape : public OpRewritePattern<TFL::PackOp> {
34   using OpRewritePattern<TFL::PackOp>::OpRewritePattern;
35 
36   LogicalResult matchAndRewrite(TFL::PackOp pack_op,
37                                 PatternRewriter& rewriter) const override;
38 };
39 
40 struct SquaredDifference : public OpRewritePattern<TFL::SquaredDifferenceOp> {
41   using OpRewritePattern<TFL::SquaredDifferenceOp>::OpRewritePattern;
42 
43   LogicalResult matchAndRewrite(TFL::SquaredDifferenceOp squared_diff_op,
44                                 PatternRewriter& rewriter) const override;
45 };
46 
47 // Unroll split into a bunch of slice ops.
48 struct UnrollSplit : public OpRewritePattern<TFL::SplitOp> {
49   using OpRewritePattern<TFL::SplitOp>::OpRewritePattern;
50 
51   LogicalResult matchAndRewrite(TFL::SplitOp split_op,
52                                 PatternRewriter& rewriter) const override;
53 };
54 
55 // Unroll splitv into a bunch of slice ops.
56 struct UnrollSplitV : public OpRewritePattern<TFL::SplitVOp> {
57   using OpRewritePattern<TFL::SplitVOp>::OpRewritePattern;
58 
59   LogicalResult matchAndRewrite(TFL::SplitVOp splitv_op,
60                                 PatternRewriter& rewriter) const override;
61 };
62 
63 // Ensure bias for conv2d op.
64 struct EnsureBiasForConv2d : public OpRewritePattern<TFL::Conv2DOp> {
65   using OpRewritePattern<TFL::Conv2DOp>::OpRewritePattern;
66 
67   LogicalResult matchAndRewrite(TFL::Conv2DOp conv_op,
68                                 PatternRewriter& rewriter) const override;
69 };
70 
71 // Pad slice to 4d.
72 struct PadSlice : public OpRewritePattern<TFL::SliceOp> {
73   using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
74 
75   LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
76                                 PatternRewriter& rewriter) const override;
77 };
78 
79 // Fully connected to conv2d.
80 struct FullyConnectedToConv : public OpRewritePattern<TFL::FullyConnectedOp> {
81   using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
82 
83   LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
84                                 PatternRewriter& rewriter) const override;
85 };
86 
87 // Pad concat to 4d.
88 struct PadConcat : public OpRewritePattern<TFL::ConcatenationOp> {
89   using OpRewritePattern<TFL::ConcatenationOp>::OpRewritePattern;
90 
91   LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op,
92                                 PatternRewriter& rewriter) const override;
93 };
94 
95 }  // namespace tac
96 }  // namespace TFL
97 }  // namespace mlir
98 
99 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_
100