• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16include "mlir/IR/OpBase.td"
17include "mlir/Dialect/Func/IR/FuncOps.td"
18include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
19include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
20include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td"
21include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td"
22
//===----------------------------------------------------------------------===//
// Helper functions.
//===----------------------------------------------------------------------===//
26
// Attribute constraint on an ArrayAttr. It holds only when the array is
// non-empty and its final element, viewed as a StringAttr, has a string value
// exactly equal to `OpName`.
// NOTE(review): the last element is cast (not dyn_cast) to StringAttr, so an
// array whose last element is not a StringAttr would trip the cast assertion
// instead of simply failing the match -- confirm callers only use string arrays.
class IsFusedOpEndsWith<string OpName>
    : AttrConstraint<
        CPred<"!$_self.cast<ArrayAttr>().empty() && $_self.cast<ArrayAttr>()"
              "[$_self.cast<ArrayAttr>().size() - 1].cast<::mlir::StringAttr>()"
              ".str() == \"" # OpName # "\"">,
        "Matching fused '" # OpName # "' op at the end">;
32
//===----------------------------------------------------------------------===//
// Pattern rules for lifting ops as functions
//===----------------------------------------------------------------------===//
36
// Lifts a standalone tf.Conv2D into a LiftAsFunctionCall of
// "composite_conv2d_fn". The pattern only matches when data_format satisfies
// IsDataFormatNHWC. The conv input and filter become the call arguments, the
// conv result is the replaced value, and every remaining Conv2D attribute
// except data_format is forwarded as a named attribute. Guarded by
// IsNotInLiftedFunc on the conv result.
def LiftConv : Pat<
  (TF_Conv2DOp:$conv_out $conv_input, $conv_filter, $strides,
    $use_cudnn_on_gpu, $padding, $explicit_paddings,
    IsDataFormatNHWC:$data_format, $dilations),
  (LiftAsFunctionCall<"composite_conv2d_fn">
    (ArgumentList $conv_input, $conv_filter),
    (ResultList $conv_out),
    (NamedAttributeList
      (NamedAttr<"strides"> $strides),
      (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
      (NamedAttr<"padding"> $padding),
      (NamedAttr<"explicit_paddings"> $explicit_paddings),
      (NamedAttr<"dilations"> $dilations))),
  [(IsNotInLiftedFunc $conv_out)],
  (addBenefit 1)>;
50
// Lifts a standalone tf.DepthwiseConv2dNative into a LiftAsFunctionCall of
// "composite_depthwise_conv2d_fn". The pattern only matches when data_format
// satisfies IsDataFormatNHWC. Input and filter become the call arguments, and
// strides, padding, explicit_paddings and dilations are forwarded as named
// attributes. Guarded by IsNotInLiftedFunc on the op's result.
def LiftDepthwiseConv : Pat<
  (TF_DepthwiseConv2dNativeOp:$dw_out $dw_input, $dw_filter, $strides,
    $padding, $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
  (LiftAsFunctionCall<"composite_depthwise_conv2d_fn">
    (ArgumentList $dw_input, $dw_filter),
    (ResultList $dw_out),
    (NamedAttributeList
      (NamedAttr<"strides"> $strides),
      (NamedAttr<"padding"> $padding),
      (NamedAttr<"explicit_paddings"> $explicit_paddings),
      (NamedAttr<"dilations"> $dilations))),
  [(IsNotInLiftedFunc $dw_out)],
  (addBenefit 1)>;
63
// Lifts a standalone tf.MatMul into a LiftAsFunctionCall of
// "composite_matmul_fn". Both operands become the call arguments and the
// transpose_a / transpose_b flags are forwarded as named attributes. Guarded by
// IsNotInLiftedFunc on the MatMul result.
def LiftMatMul : Pat<
  (TF_MatMulOp:$product $lhs, $rhs, $transpose_a, $transpose_b),
  (LiftAsFunctionCall<"composite_matmul_fn">
    (ArgumentList $lhs, $rhs),
    (ResultList $product),
    (NamedAttributeList
      (NamedAttr<"transpose_a"> $transpose_a),
      (NamedAttr<"transpose_b"> $transpose_b))),
  [(IsNotInLiftedFunc $product)],
  (addBenefit 1)>;
73
//===----------------------------------------------------------------------===//
// Pattern rules for lifting ops with bias as functions
//===----------------------------------------------------------------------===//
77
// Lifts tf.BiasAdd(tf.DepthwiseConv2dNative(input, filter), bias) into a
// LiftAsFunctionCall of "composite_depthwise_conv2d_with_bias_fn" taking
// (input, filter, bias). Both the convolution and the BiasAdd must satisfy
// IsDataFormatNHWC. The BiasAdd result is the replaced value. Benefit 5 ranks
// this above the benefit-1 pattern that lifts the bare depthwise convolution.
def LiftDepthwiseConv2dNativeWithBias : Pat<
  (TF_BiasAddOp:$biased_out
    (TF_DepthwiseConv2dNativeOp $dw_input, $dw_filter, $strides, $padding,
      $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
    $bias, IsDataFormatNHWC:$bias_data_format),
  (LiftAsFunctionCall<"composite_depthwise_conv2d_with_bias_fn">
    (ArgumentList $dw_input, $dw_filter, $bias),
    (ResultList $biased_out),
    (NamedAttributeList
      (NamedAttr<"strides"> $strides),
      (NamedAttr<"padding"> $padding),
      (NamedAttr<"explicit_paddings"> $explicit_paddings),
      (NamedAttr<"dilations"> $dilations))),
  [(IsNotInLiftedFunc $biased_out)],
  (addBenefit 5)>;
92
// Lifts tf.BiasAdd(tf.Conv2D(input, filter), bias) into a LiftAsFunctionCall
// of "composite_conv2d_with_bias_fn" taking (input, filter, bias). Both the
// convolution and the BiasAdd must satisfy IsDataFormatNHWC. All Conv2D
// attributes other than data_format are forwarded. Benefit 5 ranks this above
// the benefit-1 pattern that lifts the bare Conv2D.
def LiftConv2dWithBias : Pat<
  (TF_BiasAddOp:$biased_out
    (TF_Conv2DOp $conv_input, $conv_filter, $strides, $use_cudnn_on_gpu,
      $padding, $explicit_paddings, IsDataFormatNHWC:$data_format,
      $dilations),
    $bias, IsDataFormatNHWC:$bias_data_format),
  (LiftAsFunctionCall<"composite_conv2d_with_bias_fn">
    (ArgumentList $conv_input, $conv_filter, $bias),
    (ResultList $biased_out),
    (NamedAttributeList
      (NamedAttr<"strides"> $strides),
      (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
      (NamedAttr<"padding"> $padding),
      (NamedAttr<"explicit_paddings"> $explicit_paddings),
      (NamedAttr<"dilations"> $dilations))),
  [(IsNotInLiftedFunc $biased_out)],
  (addBenefit 5)>;
108
// Lifts tf.BiasAdd(tf.MatMul(a, b), bias) into a LiftAsFunctionCall of
// "composite_matmul_with_bias_fn" taking (a, b, bias). The BiasAdd must
// satisfy IsDataFormatNHWC. The transpose flags of the MatMul are forwarded.
// Benefit 5 ranks this above the benefit-1 pattern that lifts the bare MatMul.
def LiftMatmulWithBias : Pat<
  (TF_BiasAddOp:$biased_out
    (TF_MatMulOp $lhs, $rhs, $transpose_a, $transpose_b),
    $bias, IsDataFormatNHWC:$bias_data_format),
  (LiftAsFunctionCall<"composite_matmul_with_bias_fn">
    (ArgumentList $lhs, $rhs, $bias),
    (ResultList $biased_out),
    (NamedAttributeList
      (NamedAttr<"transpose_a"> $transpose_a),
      (NamedAttr<"transpose_b"> $transpose_b))),
  [(IsNotInLiftedFunc $biased_out)],
  (addBenefit 5)>;
120
//===----------------------------------------------------------------------===//
// Pattern rules for lifting ops followed by an activation (with or without an
// intermediate BiasAdd) as functions
//===----------------------------------------------------------------------===//
124
// Instantiates six patterns that lift an activation applied on top of a
// Conv2D, DepthwiseConv2dNative or MatMul -- either directly or through an
// intermediate tf.BiasAdd -- into a LiftAsFunctionCall of the corresponding
// "composite_<op>_with_[bias_and_]<ActivationName>_fn" function.
//
// Template arguments:
//   ActivationOp   - activation op matched at the root of every pattern. Its
//                    record name is pasted into every generated def name.
//   ActivationName - suffix spliced into the composite function name
//                    (e.g. "relu").
//
// Every pattern uses benefit 10, higher than the bias-only (5) and bare-op (1)
// lifting patterns in this file, so the activation-fused form is preferred
// when several could match. Convolution patterns require IsDataFormatNHWC on
// the convolution, and on the BiasAdd when one is present.
multiclass LiftCompositeOpsWithActivation<Op ActivationOp, string ActivationName> {
  // Activation(Conv2D(input, filter)).
  def LiftConvWith#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding,
        $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations)),
    (LiftAsFunctionCall<"composite_conv2d_with_"# ActivationName #"_fn">
      (ArgumentList $input, $filter),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // Activation(BiasAdd(Conv2D(input, filter), bias)).
  // The def name pastes the ActivationOp template argument, like its siblings,
  // so each instantiation gets an activation-specific record name.
  def LiftConv2dWithBiasAnd#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_BiasAddOp
        (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding,
          $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
        $bias, IsDataFormatNHWC:$bias_data_format)),
    (LiftAsFunctionCall<"composite_conv2d_with_bias_and_"# ActivationName #"_fn">
      (ArgumentList $input, $filter, $bias),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // Activation(DepthwiseConv2dNative(input, filter)).
  def LiftDepthwiseConv2dNativeWith#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding,
        $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations)),
    (LiftAsFunctionCall<"composite_depthwise_conv2d_with_"# ActivationName #"_fn">
      (ArgumentList $input, $filter),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // Activation(BiasAdd(DepthwiseConv2dNative(input, filter), bias)).
  def LiftDepthwiseConv2dNativeWithBiasAnd#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_BiasAddOp
        (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding,
          $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
        $bias, IsDataFormatNHWC:$bias_data_format)),
    (LiftAsFunctionCall<"composite_depthwise_conv2d_with_bias_and_"# ActivationName #"_fn">
      (ArgumentList $input, $filter, $bias),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // Activation(MatMul(a, b)).
  def LiftMatmulWith#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_MatMulOp $a, $b, $transpose_a, $transpose_b)),
    (LiftAsFunctionCall<"composite_matmul_with_"# ActivationName #"_fn">
      (ArgumentList $a, $b),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"transpose_a"> $transpose_a),
        (NamedAttr<"transpose_b"> $transpose_b))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // Activation(BiasAdd(MatMul(a, b), bias)).
  def LiftMatmulWithBiasAnd#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_BiasAddOp
        (TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
        $bias, IsDataFormatNHWC:$bias_data_format)),
    (LiftAsFunctionCall<"composite_matmul_with_bias_and_"# ActivationName #"_fn">
      (ArgumentList $a, $b, $bias),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"transpose_a"> $transpose_a),
        (NamedAttr<"transpose_b"> $transpose_b))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;
}
// Instantiate the activation-fused lifting patterns for ReLU and ReLU6.
defm : LiftCompositeOpsWithActivation<TF_ReluOp,  /*ActivationName=*/"relu">;
defm : LiftCompositeOpsWithActivation<TF_Relu6Op, /*ActivationName=*/"relu6">;
215