1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_ 17 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_ 18 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project 20 #include "mlir/IR/PatternMatch.h" // from @llvm-project 21 #include "mlir/Support/LogicalResult.h" // from @llvm-project 22 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" 24 25 namespace mlir { 26 namespace TFL { 27 namespace tac { 28 29 // TODO(renjieliu): add more patterns. 30 31 // This basically: 32 // Pack => (Concat -> Reshape) 33 struct LowerPackIntoConcatReshape : public OpRewritePattern<TFL::PackOp> { 34 using OpRewritePattern<TFL::PackOp>::OpRewritePattern; 35 36 LogicalResult matchAndRewrite(TFL::PackOp pack_op, 37 PatternRewriter& rewriter) const override; 38 }; 39 40 struct SquaredDifference : public OpRewritePattern<TFL::SquaredDifferenceOp> { 41 using OpRewritePattern<TFL::SquaredDifferenceOp>::OpRewritePattern; 42 43 LogicalResult matchAndRewrite(TFL::SquaredDifferenceOp squared_diff_op, 44 PatternRewriter& rewriter) const override; 45 }; 46 47 // Unroll split into a bunch of slice ops. 48 struct UnrollSplit : public OpRewritePattern<TFL::SplitOp> { 49 using OpRewritePattern<TFL::SplitOp>::OpRewritePattern; 50 51 LogicalResult matchAndRewrite(TFL::SplitOp split_op, 52 PatternRewriter& rewriter) const override; 53 }; 54 55 // Unroll splitv into a bunch of slice ops. 56 struct UnrollSplitV : public OpRewritePattern<TFL::SplitVOp> { 57 using OpRewritePattern<TFL::SplitVOp>::OpRewritePattern; 58 59 LogicalResult matchAndRewrite(TFL::SplitVOp splitv_op, 60 PatternRewriter& rewriter) const override; 61 }; 62 63 // Ensure bias for conv2d op. 64 struct EnsureBiasForConv2d : public OpRewritePattern<TFL::Conv2DOp> { 65 using OpRewritePattern<TFL::Conv2DOp>::OpRewritePattern; 66 67 LogicalResult matchAndRewrite(TFL::Conv2DOp conv_op, 68 PatternRewriter& rewriter) const override; 69 }; 70 71 // Pad slice to 4d. 72 struct PadSlice : public OpRewritePattern<TFL::SliceOp> { 73 using OpRewritePattern<TFL::SliceOp>::OpRewritePattern; 74 75 LogicalResult matchAndRewrite(TFL::SliceOp slice_op, 76 PatternRewriter& rewriter) const override; 77 }; 78 79 // Fully connected to conv2d. 80 struct FullyConnectedToConv : public OpRewritePattern<TFL::FullyConnectedOp> { 81 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern; 82 83 LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op, 84 PatternRewriter& rewriter) const override; 85 }; 86 87 // Pad concat to 4d. 88 struct PadConcat : public OpRewritePattern<TFL::ConcatenationOp> { 89 using OpRewritePattern<TFL::ConcatenationOp>::OpRewritePattern; 90 91 LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op, 92 PatternRewriter& rewriter) const override; 93 }; 94 95 } // namespace tac 96 } // namespace TFL 97 } // namespace mlir 98 99 #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_ 100