/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the optimization pattern definition file for TensorFlow Lite.

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

// Checks if the param passed is an F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
        "32 bit float constant tensor">;

// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
        "float constant tensor">;

// Checks if the param passed is of NoneType.
def IsNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;

def ExtractSingleElementAsFloat : NativeCodeCall<
    "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() <= " # n>>;

// Checks if the value has rank 'n'.
class HasRank<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() == " # n>>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
multiclass FuseActFnIntoConvOpPat<Op ActFnOp, Attr ActFnAttr> {
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                 $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
        $padding, $stride_h, $stride_w),
    [(HasOneUse $conv_out)]>;
  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
                $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
        ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $conv_out)]>;
}

multiclass FuseActFnIntoPoolOpPat<Op ActFnOp, Attr ActFnAttr> {
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height,
                  $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding,
        $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $pool_out)]>;
  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h,
        $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $pool_out)]>;
}

// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<!cast<Op>(actFnPair[0]), !cast<Attr>(actFnPair[1])>;
  defm : FuseActFnIntoPoolOpPat<!cast<Op>(actFnPair[0]), !cast<Attr>(actFnPair[1])>;
}

class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;

// If we see a binary op (add, sub) adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The
// following patterns are restricted to float constant values for now.
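// As an illustrative sketch (not an additional pattern), with a float
// constant c the rewrite below corresponds to
//   add(conv2d(x, w, bias), c)  ==>  conv2d(x, w, add(bias, c))
// where the inner add only involves constants and is later constant folded
// into a single new bias tensor.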
multiclass FuseBinaryOpToPrecedingAffine<Op binaryOp> {
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                TFL_AF_None, $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
      (binaryOp (ConstantOp $bias),
         (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                (ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
      (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
      $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (ConstantOp FloatElementsAttr:$bias), $padding,
                $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (binaryOp (ConstantOp $bias),
         (ConstantOp $value), TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  // Fuse for TransposeConv with no bias
  def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (ConstantOp $bias), $padding,
                $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (ConstantOp $value),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp]<Op> in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;

def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;

def ExpandTo4DForDepthwiseConv: NativeCodeCall<
  "ExpandTo4DForDepthwiseConv($0)">;

// If we see a div or mul op dividing/multiplying the result of a convolution
// op (with constant filter and bias) by a constant value, we can fuse the
// div/mul into the convolution op by constant folding the filter/bias with
// the div/mul op's constant operand.
// The following patterns are restricted to float constant values for now.

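// Illustrative sketch (not an additional pattern): because convolution is
// linear in its filter, c * (conv2d(x, w) + b) = conv2d(x, c * w) + c * b, so
//   mul(conv2d(x, w, b), c)  ==>  conv2d(x, mul(w, expand_to_4d(c)), mul(b, c))
// where both new ops only involve constants; the div case is analogous.
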
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<Op BinaryOp> {
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
                (ConstantOp FloatElementsAttr:$filter),
                (ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input,
      (BinaryOp
        (ConstantOp $filter),
        (ConstantOp (ExpandTo4DForDepthwiseConv $value)),
        TFL_AF_None),
      (BinaryOp
        (ConstantOp $bias),
        (ConstantOp $value),
        TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h,
      $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp (TFL_Conv2DOp:$conv_output $input,
                (ConstantOp FloatElementsAttr:$filter),
                (ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None,
                $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input,
      (BinaryOp (ConstantOp $filter),
        (ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      (BinaryOp (ConstantOp $bias),
        (ConstantOp $value),
        TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $conv_output)]>;
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                (ConstantOp FloatElementsAttr:$weights), $input,
                (ConstantOp FloatElementsAttr:$bias),
                $padding, $stride_h, $stride_w),
              (ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
      (BinaryOp (ConstantOp $weights),
        (ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      $input,
      (BinaryOp (ConstantOp $bias),
        (ConstantOp $value),
        TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                (ConstantOp FloatElementsAttr:$weights), $input,
                (ConstantOp $bias),
                $padding, $stride_h, $stride_w),
              (ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
      (BinaryOp (ConstantOp $weights),
        (ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      $input,
      (ConstantOp $bias),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}

foreach BinaryOp = [TFL_DivOp, TFL_MulOp]<Op> in
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<BinaryOp>;


// This pattern applies when the same quantize/dequantize pair has been used
// twice with the same scale; we want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of the pre-quantize pass.
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
  (replaceWithValue $in),
  [(NotFromQuantOpOrSameQuantType $in, $qt)]>;

// Matching HardSwish
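// These patterns recognize the expanded form of
//   hard_swish(x) = x * relu6(x + 3) / 6
// where 1/6 appears as the constant 0.166666666 and the relu6 is expressed
// as the fused activation of the inner add; the variants below cover
// different associations and operand orders of the two multiplications.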
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x, (TFL_AddOp
          $x,
          (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6),
     TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
    $x,
    (TFL_MulOp
     (TFL_AddOp
      $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
     TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern3 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x,
     (TFL_AddOp
      $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern4 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     (TFL_AddOp
      $x,
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
    $x,
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Matching HardSwish with extra FakeQuant. These FakeQuant ops arise from
// incorrect placement during quantization-aware training.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
    (TFL_MulOp
     $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
          $x,
          (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6), $qattr2)),
     TFL_AF_None), $qattr1)),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Constraint that the attribute's single float element has an absolute value
// less than 'n'.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
  "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
  # n>>;

def L2NormValidReduceIndex : Constraint<CPred<
  "L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;

// Currently L2Normalization doesn't support an activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
multiclass L2NormalizePatterns<Op FirstOp, Op SecondOp> {
  // This pattern constructs L2NormalizationOp from
  // Mul->Rsqrt->Sum->Square or
  // Div->Sqrt->Sum->Square,
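  // i.e. the expanded form of
  //   l2_normalize(x, axis) = x / sqrt(sum(square(x), axis))
  //                         = x * rsqrt(sum(square(x), axis))
  // (illustrative; the validity of the reduce axis is checked below).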
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
                  (FirstOp $x,
                     (SecondOp
                        (TFL_SumOp
                           (TFL_SquareOp:$sq_op $x),
                           (ConstantOp I32ElementsAttr:$axis),
                           $keep_dims)),
                     TFL_AF_None),
           (TFL_L2NormalizationOp $x, TFL_AF_None),
           [(L2NormValidReduceIndex $sq_op, $axis)]>;

  // The patterns below handle L2Normalize when there is an Add or Maximum
  // adding or clamping with a small constant scalar (epsilon).
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
                    (FirstOp $x,
                     (SecondOp
                      (TFL_AddOp
                       (TFL_SumOp
                        (TFL_SquareOp:$sq_op $x),
                        (ConstantOp I32ElementsAttr:$axis),
                        $keep_dims),
                       (ConstantOp $epsilon), TFL_AF_None)),
           TFL_AF_None),
           (TFL_L2NormalizationOp $x, TFL_AF_None),
           [(L2NormValidReduceIndex $sq_op, $axis),
            (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
                    (FirstOp $x,
                     (SecondOp
                      (TFL_MaximumOp
                       (TFL_SumOp
                        (TFL_SquareOp:$sq_op $x),
                        (ConstantOp I32ElementsAttr:$axis),
                        $keep_dims),
                       (ConstantOp $epsilon))),
           TFL_AF_None),
           (TFL_L2NormalizationOp $x, TFL_AF_None),
           [(L2NormValidReduceIndex $sq_op, $axis),
            (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

}

foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
  in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;

//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
  "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

def OperandsBroadcastToOutputType : Constraint<CPred<
  "TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
                                     "$2.getType())">>;

def IsTailOfShape : Constraint<CPred<
  "TFL::IsTailOfShape($0.getType(), $1.getType())">>;

def Flatten : NativeCodeCall<
  "$0.cast<DenseElementsAttr>()"
    ".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
                                   "$0.getType().cast<ShapedType>().getElementType()))">;

def IsLastDimEqualToNumElements : Constraint<CPred<
  "$0.getType().cast<ShapedType>().getRank() >= 1 && "
  "$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
  "$1.getType().cast<ShapedType>().getNumElements()">>;

def IsDefinedByFullyConnectedOp : Constraint<CPred<
  "$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;

// Pattern for skipping Tile if it is mainly for broadcasting and the
// op already supports broadcasting.
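// Illustrative sketch (not an additional pattern): since these binary ops
// broadcast implicitly,
//   add(tile(x, multiples), y)  ==>  add(x, y)
// whenever broadcasting x against y already produces the result type, making
// the explicit Tile redundant.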
multiclass FuseTileBroadcastIntoFollowingBinary<Op BinaryOp> {
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
     $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
  [(OperandsBroadcastToOutputType $input, $operand, $result),
   (HasRankAtMost<4> $input),
   (HasRankAtMost<4> $operand)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result $operand,
      (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
  [(OperandsBroadcastToOutputType $operand, $input, $result),
   (HasRankAtMost<4> $operand),
   (HasRankAtMost<4> $input)]>;
}

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<Op BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
      (actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, actFnPair[1]),
    [(HasOneUse $binary_out)]>;
  }
}

foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant and its shape is
  // the tail of the other operand's shape, and the intermediate result isn't
  // used by other ops.
  // $rhs is required to be the tail shape of $lhs, so after the transformation
  // the shape of the binary op result is valid. For example, if the shapes of
  // $input, $lhs and $rhs were [1600], [1, 40, 40] and [40, 1], then after the
  // transformation the shape of the binary op result would be [40, 1600],
  // which couldn't be reshaped to [1, 40, 40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs and rules this out.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
      (ConstantOp:$rhs $a), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed by the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
    // This is valid only when both sides of the binary op are reshaped, and
    // the sizes are the same both before and after the reshape.
    def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input1, (ConstantOp:$shape1 $s1)),
                (TFL_ReshapeOp:$rhs $input2, (ConstantOp:$shape2 $s2)),
                $act_fn),
      (TFL_ReshapeOp (BinaryOp $input1, $input2, $act_fn), $shape1),
      [(IsTailOfShape $rhs, $lhs),
       (IsTailOfShape $lhs, $rhs),
       (IsTailOfShape $input1, $input2),
       (IsTailOfShape $input2, $input1)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
    // This is valid only when the last dimension of lhs is equal to the
    // number of elements in constant rhs.
    // Therefore, after the transformation, the broadcast of the binary op is
    // always applied to the last dimension of $input.
    def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
                (ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn),
      (TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr)),
                               $act_fn),
                     $shape),
      [(AnyStaticShapeTensor $input),
       (IsTailOfShape $rhs, $lhs),
       (IsLastDimEqualToNumElements $input, $rhs),
       (HasOneUse $lhs),
       // Restrict operands to have at most rank 4 because the TFLite binary
       // kernel supports up to 4D broadcast.
       (HasRankAtMost<4> $input),
       (HasRankAtMost<4> $lhs),
       (HasRankAtMost<4> $rhs),
       (IsDefinedByFullyConnectedOp $input)]>;
}


foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant and its shape is
  // the tail of the other operand's shape, and the intermediate result isn't
  // used by other ops.
  // $rhs is required to be the tail shape of $lhs, so after the transformation
  // the shape of the binary op result is valid. For example, if the shapes of
  // $input, $lhs and $rhs were [1600], [1, 40, 40] and [40, 1], then after the
  // transformation the shape of the binary op result would be [40, 1600],
  // which couldn't be reshaped to [1, 40, 40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs and rules this out.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
      (ConstantOp:$rhs $a)),
    (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed by the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
    // This is valid only when both sides of the binary op are reshaped, and
    // the sizes are the same both before and after the reshape.
    def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input1, (ConstantOp:$shape1 $s1)),
                (TFL_ReshapeOp:$rhs $input2, (ConstantOp:$shape2 $s2))),
      (TFL_ReshapeOp (BinaryOp $input1, $input2), $shape1),
      [(IsTailOfShape $rhs, $lhs),
       (IsTailOfShape $lhs, $rhs),
       (IsTailOfShape $input1, $input2),
       (IsTailOfShape $input2, $input1)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
    // This is valid only when the last dimension of lhs is equal to the
    // number of elements in constant rhs.
    // Therefore, after the transformation, the broadcast of the binary op is
    // always applied to the last dimension of $input.
    def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
                (ConstantOp:$rhs ElementsAttr:$rhs_attr)),
      (TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr))),
                     $shape),
      [(AnyStaticShapeTensor $input),
       (IsTailOfShape $rhs, $lhs),
       (IsLastDimEqualToNumElements $input, $rhs),
       (HasOneUse $lhs),
       // Restrict operands to have at most rank 4 because the TFLite binary
       // kernel supports up to 4D broadcast.
       (HasRankAtMost<4> $input),
       (HasRankAtMost<4> $lhs),
       (HasRankAtMost<4> $rhs),
       (IsDefinedByFullyConnectedOp $input)]>;
}

// Reorder the element-wise value operations and the element move operations,
// such that the value operation happens before the move operation.
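// For example (illustrative only):
//   exp(transpose(x, perm))  ==>  transpose(exp(x), perm)
// which is valid because the move op only rearranges elements while the
// element-wise op is applied to each element independently.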
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp, TFL_LogisticOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                   TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$value (MoveOp:$move $input, $move_def)),
      (MoveOp (ValueOp $input), $move_def),
      [(SameElementType $input, $value), (HasOneUse $move)]>;
  }
}

// Returns the shape of a ranked tensor.
// If called on an unranked tensor, it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Returns true if the operand type is a RankedTensorType with at most one
// dynamic dimension.
def HasValidRankedTensor : Constraint<CPred<
  "$0.getType().isa<RankedTensorType>() && "
  "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;

def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))),
  [(HasValidRankedTensor $squeeze_op)]>;

// Convert expand_dims to reshape if possible.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $expand_dims_op))),
  [(AnyStaticShapeTensor $expand_dims_op)]>;

class FloatValueEquals<string val> : Constraint<CPred<
  "FloatValueEquals($0, " # val # ")">>;

// ReLU patterns
def MatchReluPattern : Pat<
  (TFL_MaximumOp $input, (ConstantOp $Zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $Zero)]>;

def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
    (ConstantOp $One)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)),
    (ConstantOp $NegOne)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
    (TFL_MulOp:$mul_out $x,
     (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
    $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $mul_out)]>;

// Returns true if all users of this operation are in TF/TFL and don't need
// exact shape matching. This prevents removing casts on return values, which
// could break the verifier on a function type mismatch.
def AllUsersInTF : Constraint<CPred<[{
  llvm::all_of($0.getUsers(), [&](Operation *user) {
    auto name = user->getName().getDialectNamespace();
    return name == "tf" || name == "tfl";
  })
  }]>, "all users are TF/TFL operations.">;

def RemoveShapeOnlyCast : Pat<(TFL_CastOp:$output $input),
                            (replaceWithValue $input),
                            [(SameElementType $input, $output),
                             (AllUsersInTF $output)]>;


// Checks if operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() == "
  "$1.getType().cast<ShapedType>().getRank() - 1">>;

// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
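// Since Relu(x) is x for x >= 0 and 0 otherwise, this equals x for x >= 0 and
// alpha * x for x < 0, i.e. PRelu(x, alpha). The pattern captures the negated
// alpha ($neg_alpha) and re-negates it when building TFL_PReluOp.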
def MatchPRelu : Pat<
  (TFL_AddOp
   (TFL_ReluOp:$relu_out $x),
   (TFL_MulOp:$mul_out
    (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
    $neg_alpha,
    TFL_AF_None),
   TFL_AF_None),
  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $relu_out),
   (HasOneUse $mul_out),
   (HasOneUse $input_neg_out)]>;

// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//    \--> $constantB
// To
// Add --> $input
//    \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
     (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
     (ConstantOp $b), ActFun),
    (TFL_AddOp $input,
     (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
     ActFun),
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}

// We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu on int32 for now, so some test
// cases fail without this pattern; optimizing the Relu away fixes the problem.
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)),
  (TFL_SquaredDifferenceOp $l, $r)>;


// Optimize X^1 to X
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $input)>;

// Optimize X^2 to X*X
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $input, $input, TFL_AF_None)>;

// Optimize X^(1/2) to √X
def OptimizePow2ToSqrt : Pat<
  (TFL_PowOp $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.5f">)),
  (TFL_SqrtOp $input)>;

// Optimize X^(-1/2) to 1/√X == rsqrt(x)
def OptimizePow2ToRsqrt : Pat<
  (TFL_PowOp $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "-0.5f">)),
  (TFL_RsqrtOp $input)>;

def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<CPred<
  "TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
  "$0, $1.cast<DenseIntElementsAttr>())">>;

def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp $params, (ConstantOp I32ElementsAttr: $indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp (ConstantOp I32ElementsAttr: $indices), $params, $ignored),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

def ShapeMatchesReduceWithKeepAxes : Constraint<CPred<
  "ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;

// Fold reshapes re-inserting reduced dimensions into the results of a reduction
// with `keep_dims=false` by changing it to one using `keep_dims=true`.
foreach ReduceOp = [TFL_MeanOp, TFL_ReduceMaxOp, TFL_ReduceMinOp,
                    TFL_ReduceProdOp, TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
      (ReduceOp:$reduce $input, (ConstantOp I32ElementsAttr: $axes),
                        ConstBoolAttrFalse),
      (ConstantOp I32ElementsAttr: $shape)),
    (ReduceOp $input, (ConstantOp $axes), ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduce)]>;
}


def IsSame : Constraint<CPred<"$0 == $1">>;
def HasTwoUse : Constraint<CPred<
  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "($0.cast<DenseIntElementsAttr>().getValue<APInt>({0}) == "
  "$1.getType().cast<ShapedType>().getRank() - 1 || $0.cast<DenseIntElementsAttr>().getValue<int32_t>({0}) == -1)">>;

// Convert exp(x)/sum(exp(x)) into softmax.
def OptimizeToSoftmax : Pat<
  (TFL_DivOp (TFL_ExpOp:$exp $input),
             (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes),
                             ConstBoolAttrTrue), TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp, $sum_input),
   (AxesIsLastDimension $axes, $sum_input),
   (HasTwoUse $exp),
   (HasOneUse $sum)]>;

// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
// with the max normalization.
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
    (TFL_SubOp:$sub $input,
      (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes),
                            ConstBoolAttrTrue),
    TFL_AF_None),
    $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $max_input),
   (AxesIsLastDimension $axes, $max_input),
   (HasOneUse $sub),
   (HasOneUse $max)]>;

def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;

class AllElementsAreF32<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
   "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
   "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
               "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
               "[](float v){ return v == " #val# ";}))">>;

// Optimize X*1 to X
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp:$result $input,
             (ConstantOp $constant),
             TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $result),
   (AllElementsAreF32<"1.0f"> $constant)]>;

class AllElementsAreBool<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
   "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
   "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
               "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
               "[](bool v){ return v == " #val# ";}))">>;

// Remove select operators when the result is known in advance.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(true_tensor, A, B) -> A
  def Optimize#SelectOp#True : Pat<
    (SelectOp:$result (ConstantOp $constant),
                      $input1,
                      $input2),
    (replaceWithValue $input1),
    [(HaveSameType $input1, $result),
     (AllElementsAreBool<"true"> $constant)]>;
  // select(false_tensor, A, B) -> B
  def Optimize#SelectOp#False : Pat<
    (SelectOp:$result (ConstantOp $constant),
                      $input1,
                      $input2),
    (replaceWithValue $input2),
    [(HaveSameType $input2, $result),
     (AllElementsAreBool<"false"> $constant)]>;
}

// Remove (log-)softmax before arg-minmax as (log-)softmax is monotonic.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
                 (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
                 (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}


def CanOptimizeIdentitySliceOp : Constraint<CPred<
  "TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;

// Remove Slice ops that slice the whole input tensor; they are effectively
// no-ops.
def OptimizeSliceOp : Pat<
  (TFL_SliceOp:$output $input, (ConstantOp $begin), (ConstantOp $size)),
  (replaceWithValue $input),
  [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;

def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0)">;

def IsLastElementEqualsOne : Constraint<CPred<
  "TFL::IsLastElementEqualsOne($0)">>;

def IsOneHotIndexAttribute : Constraint<CPred<
  "TFL::IsOneHotIndexAttribute($0)">>;

// Replace
//   Equal(Reshape(X, shape), indices)
// With
//   OneHot(X, N, true, false, -1)
// where
//  - the last value in shape is 1
//  - indices is an incrementing series from 0 to N-1 (N elements total).
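// Illustrative example: with X = [1, 2], shape = [2, 1] and
// indices = [0, 1, 2, 3] (so N = 4), the Equal produces
//   [[false, true, false, false], [false, false, true, false]]
// which is exactly OneHot(X, 4, true, false, -1).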
def ReshapeEqualOpToOneHotOp : Pat<
  (TFL_EqualOp (TFL_ReshapeOp $x, (ConstantOp $shape)),
               (ConstantOp $series)),
  (TFL_OneHotOp $x,
                (ConstantOp (GetNumElementsOrOne $series)),
                (ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "true">),
                (ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "false">),
                ConstantAttr<I32Attr, "-1">),
  [(IsLastElementEqualsOne $shape),
   (IsOneHotIndexAttribute $series)]>;

def F32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isF32()">,
  "32 bit float tensor">;
def I32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isInteger(32)">,
  "32 bit integer tensor">;

def ConvertSingleElementAttrToFloatAttr :
  NativeCodeCall<"ConvertSingleElementAttrToFloatAttr($0)">;

// Replace
//   (float)OneHot(index, depth, on_val, off_val, axis)
// With
//   OneHot(index, depth, (float)on_val, (float)off_val, axis)
def FuseOneHotAndCastToFloat : Pat<
  (TFL_CastOp:$output (TFL_OneHotOp $indices,
                                    $depth,
                                    (ConstantOp $on_val),
                                    (ConstantOp $off_val),
                                    $axis)),
  (TFL_OneHotOp $indices,
                $depth,
                (ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
                (ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
                $axis),
  [(F32ElementsVal $output)]>;

// Replace
//   OneHot(index, depth, on=1.0f, off=0.0f, axis=-1) * filter
// With
//   EmbeddingLookup(index, Transpose(filter))
//
// OneHot with on=1 off=0 axis=-1, where `index` is a single element tensor,
// creates a tensor of size depth, and all values are 0, except for the element
// at `index`, which is 1. Multiplying such a tensor with a 2D filter essentially
// returns the single column in filter as a 1D tensor. If the input has multiple
// elements, repeat this for every entry, forming the higher dimensions in the
// result tensor. For instance, if:
//   input = [1, 2]
//   depth = 4
//   filter = [[5, 7, 11, 13], [17, 19, 23, 29]]
// then:
//   onehot = [[0, 1, 0, 0], [0, 0, 1, 0]]
//   result = [[ 7, 19],   # == 1st column in filter
//             [11, 23]]   # == 2nd column in filter
// This is exactly what the EmbeddingLookup operator is doing, on the transposed
// matrix, without doing any arithmetic but only memcpy.
def ReplaceOneHotFullyConnectedWithLookup : Pat<
  (TFL_FullyConnectedOp
    (TFL_OneHotOp
      $indices,
      (ConstantOp $depth),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.0f">),
      ConstantAttr<I32Attr, "-1">),
    $filter,
    $bias,
    TFL_AF_None,
    TFL_FCWO_Default,
    ConstBoolAttrFalse),
  (TFL_EmbeddingLookupOp
    $indices,
    (TFL_TransposeOp
      $filter,
      (ConstantOp ConstantAttr<RankedI32ElementsAttr<[2]>, "{1,0}"> ))),
  [(I32ElementsVal $indices),     // lookup is not implemented for i64
   (HasRank<1> $indices),  // lookup isn't implemented for any other rank
   (IsNoneType $bias)]>;          // Maybe folded into the lookup matrix later