/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the optimization pattern definition file for TensorFlow Lite.

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

// Checks if the param passed is a F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
  "32 bit float constant tensor">;

// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
  "float constant tensor">;

// Checks if the param passed is of NoneType.
def IsNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;

def ExtractSingleElementAsFloat : NativeCodeCall<
    "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() <= " # n>>;

// Checks if the value has rank 'n'.
class HasRank<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() == " # n>>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
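// Roughly (illustrative only, not part of the generated patterns), a pair like
//   %0 = "tfl.conv_2d"(%input, %filter, %bias) {fused_activation_function = "NONE", ...}
//   %1 = "tfl.relu"(%0)
// is rewritten to a single
//   %1 = "tfl.conv_2d"(%input, %filter, %bias) {fused_activation_function = "RELU", ...}
// provided the intermediate convolution result has no other users.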
multiclass FuseActFnIntoConvOpPat<Op ActFnOp, Attr ActFnAttr> {
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                  $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
                  $padding, $stride_h, $stride_w),
    [(HasOneUse $conv_out)]>;
  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
                  $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                  $stride_w, $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                  ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $conv_out)]>;
}

multiclass FuseActFnIntoPoolOpPat<Op ActFnOp, Attr ActFnAttr> {
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height,
                  $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding,
                  $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $pool_out)]>;
  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $pool_out)]>;
}

// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<!cast<Op>(actFnPair[0]), !cast<Attr>(actFnPair[1])>;
  defm : FuseActFnIntoPoolOpPat<!cast<Op>(actFnPair[0]), !cast<Attr>(actFnPair[1])>;
}

class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;

// If we see a binary op (add, sub) adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The
// following patterns restrict to float constant values for now.
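// Roughly (illustrative only):
//   add(conv2d(x, filter, bias), c)  ==>  conv2d(x, filter, bias + c)
//   sub(conv2d(x, filter, bias), c)  ==>  conv2d(x, filter, bias - c)
// where `bias + c` / `bias - c` are folded at compile time because both
// operands are constants.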
multiclass FuseBinaryOpToPrecedingAffine<Op binaryOp> {
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                  (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                  TFL_AF_None, $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
                  (binaryOp (ConstantOp $bias),
                            (ConstantOp $value), TFL_AF_None),
                  $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                  (ConstantOp FloatElementsAttr:$bias),
                  $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                  $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
                  (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
                  $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
                  $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                  (ConstantOp FloatElementsAttr:$bias), $padding,
                  $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
                  (binaryOp (ConstantOp $bias),
                            (ConstantOp $value), TFL_AF_None),
                  $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  // Fuse for TransposeConv with no bias.
  def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                  (ConstantOp $bias), $padding,
                  $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
                  (ConstantOp $value),
                  $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp]<Op> in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;

def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;

def ExpandTo4DForDepthwiseConv: NativeCodeCall<
    "ExpandTo4DForDepthwiseConv($0)">;

// If we see a div or mul op dividing/multiplying a constant value into a
// convolution op with constant filter and bias, we can fuse the div/mul into
// the convolution op by constant folding the filter/bias and the div/mul op's
// constant operand.
// The following patterns restrict to float constant values for now.
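// Roughly (illustrative only):
//   mul(conv2d(x, filter, bias), c)  ==>  conv2d(x, filter * c, bias * c)
//   div(conv2d(x, filter, bias), c)  ==>  conv2d(x, filter / c, bias / c)
// with the constant `c` broadcast to the filter layout via ExpandTo4DForConv /
// ExpandTo4DForDepthwiseConv before folding.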

multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<Op BinaryOp> {
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
                  (ConstantOp FloatElementsAttr:$filter),
                  (ConstantOp FloatElementsAttr:$bias),
                  $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                  $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input,
          (BinaryOp
              (ConstantOp $filter),
              (ConstantOp (ExpandTo4DForDepthwiseConv $value)),
              TFL_AF_None),
          (BinaryOp
              (ConstantOp $bias),
              (ConstantOp $value),
              TFL_AF_None),
          $h_factor, $w_factor, $act_fn, $padding, $stride_h,
          $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp (TFL_Conv2DOp:$conv_output $input,
                  (ConstantOp FloatElementsAttr:$filter),
                  (ConstantOp FloatElementsAttr:$bias),
                  $h_factor, $w_factor, TFL_AF_None,
                  $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input,
          (BinaryOp (ConstantOp $filter),
                    (ConstantOp (ExpandTo4DForConv $value)),
                    TFL_AF_None),
          (BinaryOp (ConstantOp $bias),
                    (ConstantOp $value),
                    TFL_AF_None),
          $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $conv_output)]>;
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                  (ConstantOp FloatElementsAttr:$weights), $input,
                  (ConstantOp FloatElementsAttr:$bias),
                  $padding, $stride_h, $stride_w),
              (ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
          (BinaryOp (ConstantOp $weights),
                    (ConstantOp (ExpandTo4DForConv $value)),
                    TFL_AF_None),
          $input,
          (BinaryOp (ConstantOp $bias),
                    (ConstantOp $value),
                    TFL_AF_None),
          $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                  (ConstantOp FloatElementsAttr:$weights), $input,
                  (ConstantOp $bias),
                  $padding, $stride_h, $stride_w),
              (ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
          (BinaryOp (ConstantOp $weights),
                    (ConstantOp (ExpandTo4DForConv $value)),
                    TFL_AF_None),
          $input,
          (ConstantOp $bias),
          $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}

foreach BinaryOp = [TFL_DivOp, TFL_MulOp]<Op> in
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<BinaryOp>;


// This pattern applies when the same quantize/dequantize have been used twice
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
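// Roughly (illustrative only): quantize(dequantize(%in)) collapses to %in when
// %in either is not produced by a quantize op or already carries the same
// quantized type as the target, so removing the pair does not change the
// quantization behavior.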
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
  (replaceWithValue $in),
  [(NotFromQuantOpOrSameQuantType $in, $qt)]>;

// Matching HardSwish
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x, (TFL_AddOp
          $x,
          (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6),
     TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
   $x,
   (TFL_MulOp
    (TFL_AddOp
     $x,
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
     TFL_AF_Relu6),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
   TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern3 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x,
     (TFL_AddOp
      $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern4 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     (TFL_AddOp
      $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
    $x,
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
// incorrect placement in the quantization aware training.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
    (TFL_MulOp
     $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
          $x,
          (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6), $qattr2)),
     TFL_AF_None), $qattr1)),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Constraint that the attribute value is less than 'n'.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
        "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
        "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
        # n>>;

def L2NormValidReduceIndex : Constraint<CPred<
  "L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;

// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. the constraints on axis and depth.
multiclass L2NormalizePatterns<Op FirstOp, Op SecondOp> {
  // This pattern constructs L2NormalizationOp from
  // Mul->Rsqrt->Sum->Square or
  // Div->Sqrt->Sum->Square
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
    (FirstOp $x,
     (SecondOp
      (TFL_SumOp
       (TFL_SquareOp:$sq_op $x),
       (ConstantOp I32ElementsAttr:$axis),
       $keep_dims)),
     TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis)]>;

  // The patterns below handle L2Normalize when there is an Add or Maximum
  // adding or clamping to a small constant scalar.
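  // Roughly (illustrative only), both forms compute
  //   x / sqrt(sum(x^2, axis) + epsilon)    or
  //   x / sqrt(max(sum(x^2, axis), epsilon))
  // which is recognized as tfl.l2_normalization as long as epsilon is small
  // (< 1e-3) and the reduction axis is valid for the kernel.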
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
    (FirstOp $x,
     (SecondOp
      (TFL_AddOp
       (TFL_SumOp
        (TFL_SquareOp:$sq_op $x),
        (ConstantOp I32ElementsAttr:$axis),
        $keep_dims),
       (ConstantOp $epsilon), TFL_AF_None)),
     TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
    (FirstOp $x,
     (SecondOp
      (TFL_MaximumOp
       (TFL_SumOp
        (TFL_SquareOp:$sq_op $x),
        (ConstantOp I32ElementsAttr:$axis),
        $keep_dims),
       (ConstantOp $epsilon))),
     TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

}

foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
  in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;

//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
  "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

def OperandsBroadcastToOutputType : Constraint<CPred<
  "TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
  "$2.getType())">>;

def IsTailOfShape : Constraint<CPred<
  "TFL::IsTailOfShape($0.getType(), $1.getType())">>;

def Flatten : NativeCodeCall<
  "$0.cast<DenseElementsAttr>()"
  ".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
  "$0.getType().cast<ShapedType>().getElementType()))">;

def IsLastDimEqualToNumElements : Constraint<CPred<
  "$0.getType().cast<ShapedType>().getRank() >= 1 && "
  "$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
  "$1.getType().cast<ShapedType>().getNumElements()">>;

def IsDefinedByFullyConnectedOp : Constraint<CPred<
  "$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;

// Pattern for skipping Tile if it is mainly for broadcasting and the
// op already supports broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<Op BinaryOp> {
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
     $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
    [(OperandsBroadcastToOutputType $input, $operand, $result),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $operand)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result $operand,
     (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
    [(OperandsBroadcastToOutputType $operand, $input, $result),
     (HasRankAtMost<4> $operand),
     (HasRankAtMost<4> $input)]>;
}

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
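// Roughly (illustrative only):
//   relu(add(x, y))  ==>  add(x, y) with fused_activation_function = "RELU"
// and similarly for sub/mul/div with RELU6 and RELU_N1_TO_1, provided the
// intermediate result has no other users.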
multiclass FusedBinaryActivationFuncOpPat<Op BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
      (actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, actFnPair[1]),
      [(HasOneUse $binary_out)]>;
  }
}

foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, the shape is the
  // tail of the other operand, and the intermediate result isn't used by
  // other ops.
  // $rhs is required to be the tail shape of $lhs, so after the transformation
  // the shape of the binary op result is valid. For example, assume the shapes
  // of $input, $lhs and $rhs are [1600], [1, 40, 40] and [40, 1]. After the
  // transformation, the shape of the binary op result would be [40, 1600],
  // which couldn't be reshaped to [1, 40, 40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
     (ConstantOp:$rhs $a), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
  // This is valid only when both sides of the binary op are reshaped, and
  // the sizes are the same both before and after the reshape.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input1, (ConstantOp:$shape1 $s1)),
              (TFL_ReshapeOp:$rhs $input2, (ConstantOp:$shape2 $s2)),
              $act_fn),
    (TFL_ReshapeOp (BinaryOp $input1, $input2, $act_fn), $shape1),
    [(IsTailOfShape $rhs, $lhs),
     (IsTailOfShape $lhs, $rhs),
     (IsTailOfShape $input1, $input2),
     (IsTailOfShape $input2, $input1)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
  // This is valid only when the last dimension of lhs is equal to the
  // number of elements in constant rhs.
  // Therefore, after the transformation the broadcast of the binary op is
  // always applied to the last dimension of $input.
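  // Roughly (illustrative only): with $input of shape [2, 40] produced by a
  // fully_connected op, $lhs = reshape($input, [2, 8, 5]) and a constant $rhs
  // of shape [8, 5] (40 elements),
  //   add(reshape($input, [2, 8, 5]), $rhs)
  // becomes
  //   reshape(add($input, flatten($rhs)), [2, 8, 5])
  // where flatten($rhs) has shape [40] and broadcasts over the last dimension
  // of $input.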
  def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
              (ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr)),
                    $act_fn),
     $shape),
    [(AnyStaticShapeTensor $input),
     (IsTailOfShape $rhs, $lhs),
     (IsLastDimEqualToNumElements $input, $rhs),
     (HasOneUse $lhs),
     // Restrict operands to have at most rank 4 because the TFLite binary
     // kernel supports up to 4D broadcast.
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (IsDefinedByFullyConnectedOp $input)]>;
}

foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, the shape is the
  // tail of the other operand, and the intermediate result isn't used by
  // other ops.
  // $rhs is required to be the tail shape of $lhs, so after the transformation
  // the shape of the binary op result is valid. For example, assume the shapes
  // of $input, $lhs and $rhs are [1600], [1, 40, 40] and [40, 1]. After the
  // transformation, the shape of the binary op result would be [40, 1600],
  // which couldn't be reshaped to [1, 40, 40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
     (ConstantOp:$rhs $a)),
    (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
  // This is valid only when both sides of the binary op are reshaped, and
  // the sizes are the same both before and after the reshape.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input1, (ConstantOp:$shape1 $s1)),
              (TFL_ReshapeOp:$rhs $input2, (ConstantOp:$shape2 $s2))),
    (TFL_ReshapeOp (BinaryOp $input1, $input2), $shape1),
    [(IsTailOfShape $rhs, $lhs),
     (IsTailOfShape $lhs, $rhs),
     (IsTailOfShape $input1, $input2),
     (IsTailOfShape $input2, $input1)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
  // This is valid only when the last dimension of lhs is equal to the
  // number of elements in constant rhs.
  // Therefore, after the transformation the broadcast of the binary op is
  // always applied to the last dimension of $input.
  def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
              (ConstantOp:$rhs ElementsAttr:$rhs_attr)),
    (TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr))),
     $shape),
    [(AnyStaticShapeTensor $input),
     (IsTailOfShape $rhs, $lhs),
     (IsLastDimEqualToNumElements $input, $rhs),
     (HasOneUse $lhs),
     // Restrict operands to have at most rank 4 because the TFLite binary
     // kernel supports up to 4D broadcast.
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (IsDefinedByFullyConnectedOp $input)]>;
}

// Reorder the element-wise value operations and the element move operations,
// such that the value operation happens before the move operation.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp, TFL_LogisticOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                    TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$value (MoveOp:$move $input, $move_def)),
      (MoveOp (ValueOp $input), $move_def),
      [(SameElementType $input, $value), (HasOneUse $move)]>;
  }
}

// Returns the shape of a ranked tensor.
// If called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Returns True if the operand type is RankedTensorType and valid.
def HasValidRankedTensor : Constraint<CPred<
  "$0.getType().isa<RankedTensorType>() && "
  "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;

def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))),
  [(HasValidRankedTensor $squeeze_op)]>;

// Convert expand_dims to reshape if possible.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $expand_dims_op))),
  [(AnyStaticShapeTensor $expand_dims_op)]>;

class FloatValueEquals<string val> : Constraint<CPred<
  "FloatValueEquals($0, " # val # ")">>;

// ReLU patterns
def MatchReluPattern : Pat<
  (TFL_MaximumOp $input, (ConstantOp $Zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $Zero)]>;

def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
   (ConstantOp $One)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)),
   (ConstantOp $NegOne)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
   (TFL_MulOp:$mul_out $x,
    (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
   $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $mul_out)]>;

// Returns True if all users of this operation are in TF/TFL and don't need
// exact shape matching. This prevents removing a cast on return values, which
// could break the verifier on function type mismatch.
def AllUsersInTF : Constraint<CPred<[{
  llvm::all_of($0.getUsers(), [&](Operation *user) {
    auto name = user->getName().getDialectNamespace();
    return name == "tf" || name == "tfl";
  })
  }]>, "all users are TF/TFL operations.">;

def RemoveShapeOnlyCast : Pat<(TFL_CastOp:$output $input),
                              (replaceWithValue $input),
                              [(SameElementType $input, $output),
                               (AllUsersInTF $output)]>;


// Checks if the operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() == "
        "$1.getType().cast<ShapedType>().getRank() - 1">>;

// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def MatchPRelu : Pat<
  (TFL_AddOp
   (TFL_ReluOp:$relu_out $x),
   (TFL_MulOp:$mul_out
    (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
    $neg_alpha,
    TFL_AF_None),
   TFL_AF_None),
  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $relu_out),
   (HasOneUse $mul_out),
   (HasOneUse $input_neg_out)]>;

// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//    \--> $constantB
// To
// Add --> $input
//    \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
     (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
     (ConstantOp $b), ActFun),
    (TFL_AddOp $input,
     (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
     ActFun),
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}

// We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu on int32 for now, so test cases
// fail without the following pattern, which optimizes the Relu away.
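// Roughly (illustrative only):
//   relu(squared_difference(x, y)) == squared_difference(x, y)
// because (x - y)^2 >= 0 for every element, so the clamp at zero never fires.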
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)),
  (TFL_SquaredDifferenceOp $l, $r)>;

// Optimize X^1 to X
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp $input,
   (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $input)>;

// Optimize X^2 to X*X
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp $input,
   (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $input, $input, TFL_AF_None)>;

// Optimize X^(1/2) to √X
def OptimizePow2ToSqrt : Pat<
  (TFL_PowOp $input,
   (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.5f">)),
  (TFL_SqrtOp $input)>;

// Optimize X^(-1/2) to 1/√X == rsqrt(x)
def OptimizePow2ToRsqrt : Pat<
  (TFL_PowOp $input,
   (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "-0.5f">)),
  (TFL_RsqrtOp $input)>;

def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<CPred<
  "TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
  "$0, $1.cast<DenseIntElementsAttr>())">>;

def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp $params, (ConstantOp I32ElementsAttr: $indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp (ConstantOp I32ElementsAttr: $indices), $params, $ignored),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

def ShapeMatchesReduceWithKeepAxes : Constraint<CPred<
  "ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;

// Fold reshapes re-inserting reduced dimensions into the results of a reduction
// with `keep_dims=false` by changing it to one using `keep_dims=true`.
foreach ReduceOp = [TFL_MeanOp, TFL_ReduceMaxOp, TFL_ReduceMinOp,
                    TFL_ReduceProdOp, TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
     (ReduceOp:$reduce $input, (ConstantOp I32ElementsAttr: $axes),
      ConstBoolAttrFalse),
     (ConstantOp I32ElementsAttr: $shape)),
    (ReduceOp $input, (ConstantOp $axes), ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduce)]>;
}


def IsSame : Constraint<CPred<"$0 == $1">>;
def HasTwoUse : Constraint<CPred<
  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "($0.cast<DenseIntElementsAttr>().getValue<APInt>({0}) == "
  "$1.getType().cast<ShapedType>().getRank() - 1 || "
  "$0.cast<DenseIntElementsAttr>().getValue<int32_t>({0}) == -1)">>;

// Convert exp(x)/sum(exp(x)) into softmax.
def OptimizeToSoftmax : Pat<
  (TFL_DivOp (TFL_ExpOp:$exp $input),
   (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes),
    ConstBoolAttrTrue), TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp, $sum_input),
   (AxesIsLastDimension $axes, $sum_input),
   (HasTwoUse $exp),
   (HasOneUse $sum)]>;

// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
// with the max normalization.
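// Roughly (illustrative only): softmax is shift-invariant along the reduced
// axis, i.e.
//   softmax(x - max(x, axis=-1)) == softmax(x)
// because exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)) for any
// per-row constant m, and the softmax kernel already performs this max
// normalization internally.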
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
   (TFL_SubOp:$sub $input,
    (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes),
     ConstBoolAttrTrue),
    TFL_AF_None),
   $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $max_input),
   (AxesIsLastDimension $axes, $max_input),
   (HasOneUse $sub),
   (HasOneUse $max)]>;

def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;

class AllElementsAreF32<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
  "[](float v){ return v == " #val# ";}))">>;

// Optimize X*1 to X
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp:$result $input,
   (ConstantOp $constant),
   TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $result),
   (AllElementsAreF32<"1.0f"> $constant)]>;

class AllElementsAreBool<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
  "[](bool v){ return v == " #val# ";}))">>;

// Remove select operators when the result is known in advance.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(true_tensor, A, B) -> A
  def Optimize#SelectOp#True : Pat<
    (SelectOp:$result (ConstantOp $constant),
     $input1,
     $input2),
    (replaceWithValue $input1),
    [(HaveSameType $input1, $result),
     (AllElementsAreBool<"true"> $constant)]>;
  // select(false_tensor, A, B) -> B
  def Optimize#SelectOp#False : Pat<
    (SelectOp:$result (ConstantOp $constant),
     $input1,
     $input2),
    (replaceWithValue $input2),
    [(HaveSameType $input2, $result),
     (AllElementsAreBool<"false"> $constant)]>;
}

// Remove (log-)softmax before arg-minmax as (log-)softmax is monotonic.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
     (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
     (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}

def CanOptimizeIdentitySliceOp : Constraint<CPred<
  "TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;

// Remove Slice ops slicing the whole input tensor, effectively a no-op.
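// Roughly (illustrative only): slice(x, begin = [0, ..., 0], size = shape(x))
// returns x unchanged, so the slice is replaced with its input;
// CanOptimizeIdentitySliceOp checks the begin/size constants against the
// input type.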
def OptimizeSliceOp : Pat<
  (TFL_SliceOp:$output $input, (ConstantOp $begin), (ConstantOp $size)),
  (replaceWithValue $input),
  [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;

def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0)">;

def IsLastElementEqualsOne : Constraint<CPred<
  "TFL::IsLastElementEqualsOne($0)">>;

def IsOneHotIndexAttribute : Constraint<CPred<
  "TFL::IsOneHotIndexAttribute($0)">>;

// Replace
//   Equal(Reshape(X, shape), indices)
// With
//   OneHot(X, N, true, false, -1)
// where
//   - the last value in shape is 1
//   - indices is an incrementing series from 0 to N-1. (N elements total.)
def ReshapeEqualOpToOneHotOp : Pat<
  (TFL_EqualOp (TFL_ReshapeOp $x, (ConstantOp $shape)),
   (ConstantOp $series)),
  (TFL_OneHotOp $x,
   (ConstantOp (GetNumElementsOrOne $series)),
   (ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "true">),
   (ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "false">),
   ConstantAttr<I32Attr, "-1">),
  [(IsLastElementEqualsOne $shape),
   (IsOneHotIndexAttribute $series)]>;

def F32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isF32()">,
  "32 bit float tensor">;
def I32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isInteger(32)">,
  "32 bit integer tensor">;

def ConvertSingleElementAttrToFloatAttr :
  NativeCodeCall<"ConvertSingleElementAttrToFloatAttr($0)">;

// Replace
//   (float)OneHot(index, depth, on_val, off_val, axis)
// With
//   OneHot(index, depth, (float)on_val, (float)off_val, axis)
def FuseOneHotAndCastToFloat : Pat<
  (TFL_CastOp:$output (TFL_OneHotOp $indices,
                       $depth,
                       (ConstantOp $on_val),
                       (ConstantOp $off_val),
                       $axis)),
  (TFL_OneHotOp $indices,
                $depth,
                (ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
                (ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
                $axis),
  [(F32ElementsVal $output)]>;

// Replace
//   OneHot(index, depth, on=1.0f, off=0.0f, axis=-1) * filter
// With
//   EmbeddingLookup(index, Transpose(filter))
//
// OneHot with on=1 off=0 axis=-1, where `index` is a single element tensor,
// creates a tensor of size depth, and all values are 0, except for the element
// at `index`, which is 1. Multiplying such a tensor with a 2D filter essentially
// returns one column of the filter as a 1D tensor. If the input has multiple
// elements, repeat this for every entry, forming the higher dimensions in the
// result tensor. For instance, if:
//   input = [1, 2]
//   depth = 4
//   filter = [[5, 7, 11, 13], [17, 19, 23, 29]]
// then:
//   onehot = [[0, 1, 0, 0], [0, 0, 1, 0]]
//   result = [[ 7, 19],  # == 1st column in filter
//             [11, 23]]  # == 2nd column in filter
// This is exactly what the EmbeddingLookup operator does, on the transposed
// matrix, without doing any arithmetic but only memcpy.
def ReplaceOneHotFullyConnectedWithLookup : Pat<
  (TFL_FullyConnectedOp
   (TFL_OneHotOp
    $indices,
    (ConstantOp $depth),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.0f">),
    ConstantAttr<I32Attr, "-1">),
   $filter,
   $bias,
   TFL_AF_None,
   TFL_FCWO_Default,
   ConstBoolAttrFalse),
  (TFL_EmbeddingLookupOp
   $indices,
   (TFL_TransposeOp
    $filter,
    (ConstantOp ConstantAttr<RankedI32ElementsAttr<[2]>, "{1,0}">))),
  [(I32ElementsVal $indices),  // lookup is not implemented for i64
   (HasRank<1> $indices),      // lookup isn't implemented for any other rank
   (IsNoneType $bias)]>;       // Maybe folded into the lookup matrix later