/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

include "mlir/IR/OpBase.td"
include "mlir/Dialect/Func/IR/FuncOps.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td"
include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td"

//===----------------------------------------------------------------------===//
// Helper functions.
//===----------------------------------------------------------------------===//

// Attribute constraint that holds when the attribute, viewed as an ArrayAttr,
// is non-empty and its last element, cast to StringAttr, has the string value
// `OpName`. The emptiness check guards the `back()` access.
// NOTE(review): presumably applied to a list of fused-op names; confirm at the
// use sites (none are visible in this file).
class IsFusedOpEndsWith<string OpName> : AttrConstraint<
    CPred<"!$_self.cast<ArrayAttr>().empty() && "
          "$_self.cast<ArrayAttr>().getValue().back()"
          ".cast<::mlir::StringAttr>().getValue() == \"" # OpName # "\"">,
    "Matching fused '" # OpName # "' op at the end">;

//===----------------------------------------------------------------------===//
// Pattern rules for lifting ops as functions
//===----------------------------------------------------------------------===//

// Each rule below matches one TF op and replaces it with a LiftAsFunctionCall
// (declared in lift_as_function_call_utils.td) labelled with the given
// composite function name. The op's two tensor operands become the argument
// list, its result becomes the result list, and the listed attributes are
// forwarded as named attributes. Where a data_format operand exists it is
// only constrained to NHWC, not forwarded.
//
// All of these carry benefit 1, the lowest in this file, so the rules that
// additionally absorb a BiasAdd and/or an activation are tried first.
// IsNotInLiftedFunc also comes from an included file; by its name it keeps an
// op that already sits inside a lifted function from being lifted again.

// NHWC tf.Conv2D -> "composite_conv2d_fn".
def LiftConv : Pat<
    (TF_Conv2DOp:$out
        $input, $filter, $strides, $use_cudnn_on_gpu, $padding,
        $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
    (LiftAsFunctionCall<"composite_conv2d_fn">
        (ArgumentList $input, $filter),
        (ResultList $out),
        (NamedAttributeList
            (NamedAttr<"strides"> $strides),
            (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
            (NamedAttr<"padding"> $padding),
            (NamedAttr<"explicit_paddings"> $explicit_paddings),
            (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $out)],
    (addBenefit 1)>;

// NHWC tf.DepthwiseConv2dNative -> "composite_depthwise_conv2d_fn".
// Same as LiftConv minus use_cudnn_on_gpu, which this op does not take.
def LiftDepthwiseConv : Pat<
    (TF_DepthwiseConv2dNativeOp:$out
        $input, $filter, $strides, $padding,
        $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
    (LiftAsFunctionCall<"composite_depthwise_conv2d_fn">
        (ArgumentList $input, $filter),
        (ResultList $out),
        (NamedAttributeList
            (NamedAttr<"strides"> $strides),
            (NamedAttr<"padding"> $padding),
            (NamedAttr<"explicit_paddings"> $explicit_paddings),
            (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $out)],
    (addBenefit 1)>;

// tf.MatMul -> "composite_matmul_fn", forwarding both transpose flags.
def LiftMatMul : Pat<
    (TF_MatMulOp:$out $lhs, $rhs, $transpose_a, $transpose_b),
    (LiftAsFunctionCall<"composite_matmul_fn">
        (ArgumentList $lhs, $rhs),
        (ResultList $out),
        (NamedAttributeList
            (NamedAttr<"transpose_a"> $transpose_a),
            (NamedAttr<"transpose_b"> $transpose_b))),
    [(IsNotInLiftedFunc $out)],
    (addBenefit 1)>;

//===----------------------------------------------------------------------===//
// Pattern rules for lifting ops with bias as functions
//===----------------------------------------------------------------------===//
// NHWC tf.DepthwiseConv2dNative feeding an NHWC tf.BiasAdd
//   -> "composite_depthwise_conv2d_with_bias_fn".
// The bias becomes a third function argument. Benefit 5 ranks this above the
// single-op lifting rules (benefit 1) so the fused form is preferred.
def LiftDepthwiseConv2dNativeWithBias : Pat<
  (TF_BiasAddOp:$res
    (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding,
      $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
    $bias, IsDataFormatNHWC:$bias_data_format),
  (LiftAsFunctionCall<"composite_depthwise_conv2d_with_bias_fn">
    (ArgumentList $input, $filter, $bias),
    (ResultList $res),
    (NamedAttributeList
      (NamedAttr<"strides"> $strides),
      (NamedAttr<"padding"> $padding),
      (NamedAttr<"explicit_paddings"> $explicit_paddings),
      (NamedAttr<"dilations"> $dilations))),
  [(IsNotInLiftedFunc $res)], (addBenefit 5)>;

// NHWC tf.Conv2D feeding an NHWC tf.BiasAdd -> "composite_conv2d_with_bias_fn".
def LiftConv2dWithBias : Pat<
  (TF_BiasAddOp:$res
    (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding,
      $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
    $bias, IsDataFormatNHWC:$bias_data_format),
  (LiftAsFunctionCall<"composite_conv2d_with_bias_fn">
    (ArgumentList $input, $filter, $bias),
    (ResultList $res),
    (NamedAttributeList
      (NamedAttr<"strides"> $strides),
      (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
      (NamedAttr<"padding"> $padding),
      (NamedAttr<"explicit_paddings"> $explicit_paddings),
      (NamedAttr<"dilations"> $dilations))),
  [(IsNotInLiftedFunc $res)], (addBenefit 5)>;

// tf.MatMul feeding an NHWC tf.BiasAdd -> "composite_matmul_with_bias_fn".
def LiftMatmulWithBias : Pat<
  (TF_BiasAddOp:$res
    (TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
    $bias, IsDataFormatNHWC:$bias_data_format),
  (LiftAsFunctionCall<"composite_matmul_with_bias_fn">
    (ArgumentList $a, $b, $bias),
    (ResultList $res),
    (NamedAttributeList
      (NamedAttr<"transpose_a"> $transpose_a),
      (NamedAttr<"transpose_b"> $transpose_b))),
  [(IsNotInLiftedFunc $res)], (addBenefit 5)>;

//===----------------------------------------------------------------------===//
// Pattern rules for lifting ops with bias and activation as functions
//===----------------------------------------------------------------------===//

// Instantiates, for one activation op, lifting rules for
// {Conv2D, DepthwiseConv2dNative, MatMul} x {without, with BiasAdd}, each
// followed by that activation at the root of the match.
//
// ActivationOp:   the activation op matched at the root (e.g. TF_ReluOp).
// ActivationName: spliced into the composite function name; "relu" yields
//                 e.g. "composite_conv2d_with_relu_fn" and
//                 "composite_conv2d_with_bias_and_relu_fn".
//
// Benefit 10 ranks these above the bias-only (5) and single-op (1) rules, so
// the largest matching fusion is lifted.
//
// Every def name pastes ActivationOp so that each instantiation's records are
// named after their activation. Only paste template arguments here: in a
// record name TableGen pastes an undefined identifier verbatim instead of
// reporting an error.
multiclass LiftCompositeOpsWithActivation<Op ActivationOp, string ActivationName> {
  // Conv2D -> activation.
  def LiftConvWith#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding,
        $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations)),
    (LiftAsFunctionCall<"composite_conv2d_with_"# ActivationName #"_fn">
      (ArgumentList $input, $filter),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // Conv2D -> BiasAdd -> activation.
  def LiftConv2dWithBiasAnd#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_BiasAddOp
        (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding,
          $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
        $bias, IsDataFormatNHWC:$bias_data_format)),
    (LiftAsFunctionCall<"composite_conv2d_with_bias_and_"# ActivationName #"_fn">
      (ArgumentList $input, $filter, $bias),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // DepthwiseConv2dNative -> activation.
  def LiftDepthwiseConv2dNativeWith#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding,
        $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations)),
    (LiftAsFunctionCall<"composite_depthwise_conv2d_with_"# ActivationName #"_fn">
      (ArgumentList $input, $filter),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // DepthwiseConv2dNative -> BiasAdd -> activation.
  def LiftDepthwiseConv2dNativeWithBiasAnd#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_BiasAddOp
        (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding,
          $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations),
        $bias, IsDataFormatNHWC:$bias_data_format)),
    (LiftAsFunctionCall<"composite_depthwise_conv2d_with_bias_and_"# ActivationName #"_fn">
      (ArgumentList $input, $filter, $bias),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"strides"> $strides),
        (NamedAttr<"padding"> $padding),
        (NamedAttr<"explicit_paddings"> $explicit_paddings),
        (NamedAttr<"dilations"> $dilations))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // MatMul -> activation.
  def LiftMatmulWith#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_MatMulOp $a, $b, $transpose_a, $transpose_b)),
    (LiftAsFunctionCall<"composite_matmul_with_"# ActivationName #"_fn">
      (ArgumentList $a, $b),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"transpose_a"> $transpose_a),
        (NamedAttr<"transpose_b"> $transpose_b))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;

  // MatMul -> BiasAdd -> activation.
  def LiftMatmulWithBiasAnd#ActivationOp : Pat<
    (ActivationOp:$res
      (TF_BiasAddOp
        (TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
        $bias, IsDataFormatNHWC:$bias_data_format)),
    (LiftAsFunctionCall<"composite_matmul_with_bias_and_"# ActivationName #"_fn">
      (ArgumentList $a, $b, $bias),
      (ResultList $res),
      (NamedAttributeList
        (NamedAttr<"transpose_a"> $transpose_a),
        (NamedAttr<"transpose_b"> $transpose_b))),
    [(IsNotInLiftedFunc $res)], (addBenefit 10)>;
}

defm : LiftCompositeOpsWithActivation<TF_ReluOp, "relu">;
defm : LiftCompositeOpsWithActivation<TF_Relu6Op, "relu6">;