/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the optimization pattern definition file for TensorFlow Lite.

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

// Checks if the param passed is an F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
  "32 bit float constant tensor">;

// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
  "float constant tensor">;

// Checks if the param passed is of NoneType.
def IsNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;

def ExtractSingleElementAsFloat : NativeCodeCall<
  "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
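// For example (illustrative, schematic rather than exact IR):
//   relu(conv2d(x, filter, bias, act=NONE)) -> conv2d(x, filter, bias, act=RELU)
// provided the intermediate convolution result has no other users.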
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                  $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
                  $padding, $stride_h, $stride_w),
    [(HasOneUse $conv_out)]>;
  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
                  $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                  $stride_w, $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                  ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $conv_out)]>;
}

multiclass FuseActFnIntoPoolOpPat<dag ActFnOp, dag ActFnAttr> {
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height,
                  $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding,
                  $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $pool_out)]>;
  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $pool_out)]>;
}

// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
  defm : FuseActFnIntoPoolOpPat<actFnPair[0], actFnPair[1]>;
}

class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;

// If we see a binary op (add or sub) adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The
// following pattern restricts to float constant values for now.
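// For example (illustrative, schematic rather than exact IR):
//   add(conv2d(x, filter, bias), c) -> conv2d(x, filter, add(bias, c))
// where `bias` and `c` are float constants that get folded together.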
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                  (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                  TFL_AF_None, $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
                  (binaryOp (ConstantOp $bias),
                            (ConstantOp $value), TFL_AF_None),
                  $h_factor, $w_factor, $act_fn, $padding, $stride_h,
                  $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                  (ConstantOp FloatElementsAttr:$bias),
                  $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                  $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
                  (binaryOp (ConstantOp $bias), (ConstantOp $value),
                            TFL_AF_None),
                  $h_factor, $w_factor, $act_fn, $padding, $stride_h,
                  $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                  (ConstantOp FloatElementsAttr:$bias), $padding,
                  $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
                  (binaryOp (ConstantOp $bias),
                            (ConstantOp $value), TFL_AF_None),
                  $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  // Fuse for TransposeConv with no bias.
  def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                  (ConstantOp $bias), $padding,
                  $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
                  (ConstantOp $value),
                  $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;

def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;

def ExpandTo4DForDepthwiseConv: NativeCodeCall<
  "ExpandTo4DForDepthwiseConv($0)">;

// If we see a div or mul op dividing/multiplying the result of a convolution
// op with constant filter and bias by a constant value, we can fuse the
// div/mul into the convolution op by constant folding the filter/bias and the
// div/mul op's constant operand. The following pattern restricts to float
// constant values for now.
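// For example (illustrative, schematic rather than exact IR):
//   mul(conv2d(x, filter, bias), c)
//     -> conv2d(x, mul(filter, ExpandTo4DForConv(c)), mul(bias, c))
// with both multiplications folded away at compile time.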
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
                  (ConstantOp FloatElementsAttr:$filter),
                  (ConstantOp FloatElementsAttr:$bias),
                  $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                  $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input,
                  (BinaryOp
                      (ConstantOp $filter),
                      (ConstantOp (ExpandTo4DForDepthwiseConv $value)),
                      TFL_AF_None),
                  (BinaryOp
                      (ConstantOp $bias),
                      (ConstantOp $value),
                      TFL_AF_None),
                  $h_factor, $w_factor, $act_fn, $padding, $stride_h,
                  $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp (TFL_Conv2DOp:$conv_output $input,
                  (ConstantOp FloatElementsAttr:$filter),
                  (ConstantOp FloatElementsAttr:$bias),
                  $h_factor, $w_factor, TFL_AF_None,
                  $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input,
                  (BinaryOp (ConstantOp $filter),
                            (ConstantOp (ExpandTo4DForConv $value)),
                            TFL_AF_None),
                  (BinaryOp (ConstantOp $bias),
                            (ConstantOp $value),
                            TFL_AF_None),
                  $h_factor, $w_factor, $act_fn, $padding, $stride_h,
                  $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $conv_output)]>;
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                  (ConstantOp FloatElementsAttr:$weights), $input,
                  (ConstantOp FloatElementsAttr:$bias),
                  $padding, $stride_h, $stride_w),
              (ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
                  (BinaryOp (ConstantOp $weights),
                            (ConstantOp (ExpandTo4DForConv $value)),
                            TFL_AF_None),
                  $input,
                  (BinaryOp (ConstantOp $bias),
                            (ConstantOp $value),
                            TFL_AF_None),
                  $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                  (ConstantOp FloatElementsAttr:$weights), $input,
                  (ConstantOp $bias),
                  $padding, $stride_h, $stride_w),
              (ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
                  (BinaryOp (ConstantOp $weights),
                            (ConstantOp (ExpandTo4DForConv $value)),
                            TFL_AF_None),
                  $input,
                  (ConstantOp $bias),
                  $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}

foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<BinaryOp>;

// This pattern applies when the same quantize/dequantize have been used twice
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
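// For example (illustrative): tfl.quantize(tfl.dequantize(%q)) is replaced by
// %q, unless %q is produced by a quantize op with a different quantized type.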
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
  (replaceWithValue $in),
  [(NotFromQuantOpOrSameQuantType $in, $qt)]>;

// Checks if the operand has rank == n.
class OperandHasRank<int n> : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;

// Matching HardSwish
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
      (TFL_MulOp
          $x, (TFL_AddOp
                  $x,
                  (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
                  TFL_AF_Relu6),
          TFL_AF_None),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
      TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
      $x,
      (TFL_MulOp
          (TFL_AddOp
              $x,
              (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
              TFL_AF_Relu6),
          (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
          TFL_AF_None),
      TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Matching HardSwish with extra FakeQuant. These FakeQuant ops were added due
// to incorrect placement during quantization-aware training.
// TODO(b/149735743): We should make the placement automatic.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
      (TFL_MulOp
          $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
                  $x,
                  (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
                  TFL_AF_Relu6), $qattr2)),
          TFL_AF_None), $qattr1)),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
      TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Constraint that the attribute value is less than 'n'.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
        "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
        "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
        # n>>;

def L2NormValidReduceIndex : Constraint<CPred<
  "L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;

// Currently L2Normalization doesn't support an activation function in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
  // This pattern constructs L2NormalizationOp from
  // Mul->Rsqrt->Sum->Square or
  // Div->Sqrt->Sum->Square.
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
    (FirstOp $x,
        (SecondOp
            (TFL_SumOp
                (TFL_SquareOp:$sq_op $x),
                (ConstantOp I32ElementsAttr:$axis),
                $keep_dims)),
        TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis)]>;

  // The patterns below match L2Normalize when there is an additional Add or
  // Maximum that adds or clamps with a small constant scalar.
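  // For example (illustrative, schematic): div(x, sqrt(add(sum(square(x), axis), eps)))
  // or div(x, sqrt(max(sum(square(x), axis), eps))), with eps < 1e-3, also map
  // to L2Normalization.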
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
    (FirstOp $x,
        (SecondOp
            (TFL_AddOp
                (TFL_SumOp
                    (TFL_SquareOp:$sq_op $x),
                    (ConstantOp I32ElementsAttr:$axis),
                    $keep_dims),
                (ConstantOp $epsilon), TFL_AF_None)),
        TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
    (FirstOp $x,
        (SecondOp
            (TFL_MaximumOp
                (TFL_SumOp
                    (TFL_SquareOp:$sq_op $x),
                    (ConstantOp I32ElementsAttr:$axis),
                    $keep_dims),
                (ConstantOp $epsilon))),
        TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
}

foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
  in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;

//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
  "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

def OperandsBroadcastToOutputType : Constraint<CPred<
  "TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
  "$2.getType())">>;

def IsTailOfShape : Constraint<CPred<
  "TFL::IsTailOfShape($0.getType(), $1.getType())">>;

// Pattern for skipping Tile if it is used mainly for broadcasting and the op
// already supports broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
                      $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
    [(OperandsBroadcastToOutputType $input, $operand, $result),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $operand)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result $operand,
                      (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
    [(OperandsBroadcastToOutputType $operand, $input, $result),
     (HasRankAtMost<4> $operand),
     (HasRankAtMost<4> $input)]>;
}

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
      (actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, actFnPair[1]),
      [(HasOneUse $binary_out)]>;
  }
}

foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, the shape is the
  // tail of the other operand, and the intermediate result isn't used by other
  // ops.
  // $rhs is required to be the tail shape of $lhs, so after transformation the
  // shape of the binary op result is valid.
  // For example, assume the shapes of
  // $input, $lhs and $rhs are [1600], [1, 40, 40] and [40, 1]. After the
  // transformation, the shape of the binary op result would be [40, 1600],
  // which couldn't be reshaped to [1, 40, 40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
              (ConstantOp:$rhs $a), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;
}

foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, the shape is the
  // tail of the other operand, and the intermediate result isn't used by other
  // ops.
  // $rhs is required to be the tail shape of $lhs, so after transformation the
  // shape of the binary op result is valid. For example, assume the shapes of
  // $input, $lhs and $rhs are [1600], [1, 40, 40] and [40, 1]. After the
  // transformation, the shape of the binary op result would be [40, 1600],
  // which couldn't be reshaped to [1, 40, 40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
              (ConstantOp:$rhs $a)),
    (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;
}

// Reorder the element-wise value operations and the element move operations,
// such that the value operation happens before the move operation.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                    TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$value (MoveOp:$move $input, $move_def)),
      (MoveOp (ValueOp $input), $move_def),
      [(HasOneUse $move)]>;
  }
}

// Returns the shape of a ranked tensor.
// If called on an unranked tensor, it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Returns true if the operand type is a RankedTensorType with at most one
// dynamic dimension.
def HasValidRankedTensor : Constraint<CPred<
  "$0.getType().isa<RankedTensorType>() && "
  "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;

def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))),
  [(HasValidRankedTensor $squeeze_op)]>;

// Convert expand_dims to reshape if possible.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $expand_dims_op))),
  [(AnyStaticShapeTensor $expand_dims_op)]>;

class FloatValueEquals<string val> : Constraint<CPred<
  "FloatValueEquals($0, " # val # ")">>;

// ReLU patterns
def MatchReluPattern : Pat<
  (TFL_MaximumOp $input, (ConstantOp $Zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $Zero)]>;

def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
                 (ConstantOp $One)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)),
                 (ConstantOp $NegOne)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
      (TFL_MulOp:$mul_out $x,
                 (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
      $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $mul_out)]>;

def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
                            (replaceWithValue $input),
                            [(SameElementType $input, $output)]>;

// Checks if operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() == "
        "$1.getType().cast<ShapedType>().getRank() - 1">>;

// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def MatchPRelu : Pat<
  (TFL_AddOp
      (TFL_ReluOp:$relu_out $x),
      (TFL_MulOp:$mul_out
          (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
          $neg_alpha,
          TFL_AF_None),
      TFL_AF_None),
  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $relu_out),
   (HasOneUse $mul_out),
   (HasOneUse $input_neg_out)]>;

// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//      \--> $constantB
// To
// Add --> $input
//      \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
        (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
        (ConstantOp $b), ActFun),
    (TFL_AddOp $input,
        (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
        ActFun),
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}

// We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu on int32 for now, so without the
// following pattern, which optimizes the Relu away, the affected test cases
// fail.
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)),
  (TFL_SquaredDifferenceOp $l, $r)>;

// Optimize X^1 to X
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp $input,
             (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $input)>;

// Optimize X^2 to X*X
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp $input,
             (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $input, $input, TFL_AF_None)>;

def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<CPred<
  "TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
  "$0, $1.cast<DenseIntElementsAttr>())">>;

def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp $params, (ConstantOp I32ElementsAttr: $indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp (ConstantOp I32ElementsAttr: $indices), $params, $ignored),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

def ShapeMatchesReduceWithKeepAxes : Constraint<CPred<
  "ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;

// Fold reshapes re-inserting reduced dimensions into the results of a
// reduction with `keep_dims=false` by changing it to one using
// `keep_dims=true`.
foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp,
                    TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
        (ReduceOp:$reduce $input, (ConstantOp I32ElementsAttr: $axes),
                  ConstBoolAttrFalse),
        (ConstantOp I32ElementsAttr: $shape)),
    (ReduceOp $input, (ConstantOp $axes), ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduce)]>;
}

def IsSame : Constraint<CPred<"$0 == $1">>;
def HasTwoUse : Constraint<CPred<
  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "($0.cast<DenseIntElementsAttr>().getValue<APInt>({0}) == "
  "$1.getType().cast<ShapedType>().getRank() - 1 || $0.cast<DenseIntElementsAttr>().getValue<int32_t>({0}) == -1)">>;

// Convert exp(x)/sum(exp(x)) into softmax.
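// For example (illustrative, schematic rather than exact IR):
//   div(exp(x), sum(exp(x), axes=[-1], keep_dims=true)) -> softmax(x, beta=1.0)
// when the reduction is over the last dimension; the exp result may be used by
// both the div and the sum, but by nothing else.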
def OptimizeToSoftmax : Pat<
  (TFL_DivOp (TFL_ExpOp:$exp $input),
             (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes),
                        ConstBoolAttrTrue), TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp, $sum_input),
   (AxesIsLastDimension $axes, $sum_input),
   (HasTwoUse $exp),
   (HasOneUse $sum)]>;

// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
// with the max normalization.
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
      (TFL_SubOp:$sub $input,
          (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes),
                           ConstBoolAttrTrue),
          TFL_AF_None),
      $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $max_input),
   (AxesIsLastDimension $axes, $max_input),
   (HasOneUse $sub),
   (HasOneUse $max)]>;

def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;

class AllElementsAreF32<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
  "[](float v){ return v == " #val# ";}))">>;

// Optimize X*1 to X
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp $input,
             (ConstantOp $constant),
             TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $constant),
   (AllElementsAreF32<"1.0f"> $constant)]>;

class AllElementsAreBool<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
  "[](bool v){ return v == " #val# ";}))">>;

// Remove select operators when the result is known in advance.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(true_tensor, A, B) -> A
  def Optimize#SelectOp#True : Pat<
    (SelectOp (ConstantOp $constant),
              $input1,
              $input2),
    (replaceWithValue $input1),
    [(HaveSameType $input1, $input2),
     (IsTailOfShape $input1, $constant),
     (IsTailOfShape $constant, $input1),
     (AllElementsAreBool<"true"> $constant)]>;
  // select(false_tensor, A, B) -> B
  def Optimize#SelectOp#False : Pat<
    (SelectOp (ConstantOp $constant),
              $input1,
              $input2),
    (replaceWithValue $input2),
    [(HaveSameType $input1, $input2),
     (IsTailOfShape $input1, $constant),
     (IsTailOfShape $constant, $input1),
     (AllElementsAreBool<"false"> $constant)]>;
}

// Remove (log-)softmax before arg-minmax as (log-)softmax is monotonic.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
                 (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
                 (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}