/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LINT.IfChange
// This is the operation definition file for TensorFlow Lite.

#ifndef TFL_OPS
#define TFL_OPS

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_structs.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"

//===----------------------------------------------------------------------===//
// TFLite dialect string type - uses the TF string type as implementation
//===----------------------------------------------------------------------===//
def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
                  "TFLite string type">,
             BuildableType<"getType<mlir::TF::StringType>()">;

//===----------------------------------------------------------------------===//
// TFLite dialect quint8 type - uses the TF quint8 type as implementation
//===----------------------------------------------------------------------===//
def TFL_Quint8 : Type<CPred<"$_self.isa<mlir::TF::Quint8Type>()">,
                    "TFLite quint8 type">,
              BuildableType<"getType<mlir::TF::Quint8Type>()">;

//===----------------------------------------------------------------------===//
// Activation function enum definitions.
//===----------------------------------------------------------------------===//

// Allowed activation function cases
// These should match the ActivationFunctionType enum in TFLite schema.
def TFL_AF_None  : StrEnumAttrCase<"NONE">;
def TFL_AF_Relu  : StrEnumAttrCase<"RELU">;
def TFL_AF_Relu1 : StrEnumAttrCase<"RELU_N1_TO_1">;
def TFL_AF_Relu6 : StrEnumAttrCase<"RELU6">;
def TFL_AF_Tanh  : StrEnumAttrCase<"TANH">;
def TFL_AF_Sign  : StrEnumAttrCase<"SIGN_BIT">;

def TFL_AFAttr : StrEnumAttr<
    "ActivationFunctionType", "fused activation enum", [
      TFL_AF_None,  TFL_AF_Relu, TFL_AF_Relu1,
      TFL_AF_Relu6, TFL_AF_Tanh, TFL_AF_Sign
    ]>;
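
// Illustrative only (hypothetical op, not part of the dialect): an op with a
// fused activation declares the enum attribute alongside its operands, e.g.
//
//   let arguments = (ins TFL_FpTensor:$input,
//                    TFL_AFAttr:$fused_activation_function);
//
// and the attribute value must be one of the cases above, e.g. "RELU6".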

//===----------------------------------------------------------------------===//
// Padding enum definitions.
//===----------------------------------------------------------------------===//

// Allowed padding cases
// These should match the padding enum in TFLite schema.
def TFL_PAD_Same  : StrEnumAttrCase<"SAME">;
def TFL_PAD_Valid : StrEnumAttrCase<"VALID">;
def TFL_MIRRORPAD_Reflect : StrEnumAttrCase<"REFLECT">;
def TFL_MIRRORPAD_Symmetric : StrEnumAttrCase<"SYMMETRIC">;

def TFL_PaddingAttr : StrEnumAttr<"Padding", "padding enum", [
      TFL_PAD_Same, TFL_PAD_Valid
    ]>;

def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum", [
      TFL_MIRRORPAD_Reflect, TFL_MIRRORPAD_Symmetric
    ]>;

//===----------------------------------------------------------------------===//
// TensorType attribute definitions.
//===----------------------------------------------------------------------===//
// A type attribute containing the TensorType.
def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;

// A type attribute containing OpaqueElementsAttr and bytes.
def OpaqueBytesAttr : ElementsAttrBase<
  And<[
    CPred<"$_self.isa<OpaqueElementsAttr>()">,
    CPred<"$_self.cast<OpaqueElementsAttr>().getType()"
          ".getElementType().isInteger(8)">,
  ]>,
  "opaque bytes attribute"
 > {
  let storageType = [{ OpaqueElementsAttr }];
  let returnType = [{ OpaqueElementsAttr }];
  let convertFromStorage = "$_self";
}

//===----------------------------------------------------------------------===//
// Derived shape attribute class.
//===----------------------------------------------------------------------===//
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body, code convert> :
  DerivedAttr<"tflite::TensorType", body, convert>;

// TFL Runtime op trait predicate.
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
    GenInternalOpTrait<"TFLRuntimeOpTrait"> {
  Pred tflRuntimePredicate = pred;
  string tflRuntimeDescription = desc;
}

class TFL_OperandsHaveSameShapesOrBroadcastableShape<
    list<int> indices, int max_bcast_rank> :
  TFL_RuntimePredOpTrait<"operands do not have the same shape or "
      "broadcastable shapes within the rank " # max_bcast_rank,
    CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
            "$_op, llvm::ArrayRef<unsigned>({" # !interleave(indices, ", ") #
            "}), " # max_bcast_rank # ")">>;

// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity.

// TFL Runtime type predicate.
class TFL_RuntimeType<TypeConstraint t> {
  Pred tflRuntimeTypePredicate = t.predicate;
  string tflRuntimeTypeDescription = t.summary;
}

class TFL_AnyTypeOf<list<Type> allowedRuntimeTypes, string description = "",
                    list<Type> allowedOpTypes = [AnyType]> :
  AnyTypeOf<allowedOpTypes, description>,
  TFL_RuntimeType<AnyTypeOf<allowedRuntimeTypes, description>>;

class TFL_TensorOf<list<Type> allowedRuntimeTypes,
                   list<Type> allowedOpTypes = [AnyType]> :
  TensorOf<allowedOpTypes>, TFL_RuntimeType<TensorOf<allowedRuntimeTypes>>;

class TFL_TensorOfOrNone<list<Type> allowedRuntimeTypes, string description = "",
                         list<Type> allowedOpTypes = [AnyType]> :
  AnyTypeOf<[TFL_TensorOf<allowedOpTypes>, NoneType], description>,
  TFL_RuntimeType<AnyTypeOf<[TFL_TensorOf<allowedRuntimeTypes>, NoneType]>>;

class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
                   list<Type> allowedOpTypes = [AnyType]> :
  Variadic<TensorOf<allowedOpTypes>>,
  TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;

def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;

def TFL_BoolTensor : TFL_TensorOf<[I1]>;
def TFL_FpTensor : TFL_TensorOf<[F32]>;
def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>;
def TFL_I32Tensor : TFL_TensorOf<[I32]>;
def TFL_I64Tensor : TFL_TensorOf<[I64]>;
def TFL_Complex64Tensor : TFL_TensorOf<[Complex<F<32>>]>;
def TFL_ResourceTensor : TFL_TensorOf<[TF_Resource]>;

// TODO(jpienaar): Expand to all int types.
def TFL_IntTensor : TypeAlias<TFL_I32Tensor, "tensor of any integer type">;

class TFL_0DTensorOf<list<Type> allowedRuntimeTypes,
                     list<Type> allowedOpTypes = [AnyType]> :
  0DTensorOf<allowedOpTypes>, TFL_RuntimeType<TensorOf<allowedRuntimeTypes>>;
class TFL_1DTensorOf<list<Type> allowedRuntimeTypes,
                     list<Type> allowedOpTypes = [AnyType]> :
  1DTensorOf<allowedOpTypes>, TFL_RuntimeType<TensorOf<allowedRuntimeTypes>>;
class TFL_2DTensorOf<list<Type> allowedRuntimeTypes,
                     list<Type> allowedOpTypes = [AnyType]> :
  2DTensorOf<allowedOpTypes>, TFL_RuntimeType<TensorOf<allowedRuntimeTypes>>;

class TFL_1DTensorOfOrNone<list<Type> allowedRuntimeTypes, string description = "",
                         list<Type> allowedOpTypes = [AnyType]> :
  AnyTypeOf<[TensorOf<allowedOpTypes>, NoneType], description>,
  TFL_RuntimeType<AnyTypeOf<[TFL_1DTensorOf<allowedRuntimeTypes>, NoneType]>>;

// This is used to represent the type of "ref tensors" or tensors that are
// used as variables to track state.
def TFL_StatefulTensor : TypeAlias<AnyTensor, "stateful tensor">;

//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

// Returns true if the operand is of none type.
class TFL_OperandIsNoneType<int i> :
  CPred<"$_op.getOperand(" # i # ").getType().isa<NoneType>()">;

class TFL_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// TODO: Some of these could be generalized and/or moved to a more general
// location.
// Returns true if the n-th operand has unknown rank or has rank m.
class TFL_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ").getType().cast<ShapedType>().getRank() == " # m>]>>;
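
// Illustrative only: attaching `TFL_OperandHasRank<1, 4>` to an op constrains
// operand 1 to be either unranked or 4-D; verification fails otherwise. A
// hypothetical use (op name is made up):
//
//   def MyOp : TFL_Op<"my_op", [TFL_OperandHasRank<1, 4>]> { ... }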

// Returns true if the n-th operand is ranked and has rank dim.
class TFL_OperandHasKnownRank<int n, int dim> : And<[
  CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
    # dim>]>;

// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
  CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
  # dim>]>;

// Returns true if the n-th operand is ranked and has a dimension length = size
// at the rank dim.
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
  TFL_OperandIsRankedAndHasDimPred<n, dim>,
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
      ".getShape()[" # dim # "] == " # size>]>;

// Returns true if the n-th operand is ranked and has a dimension length <=
// size at the rank dim.
class TFL_OperandDimIsAtMost<int n, int dim, int size> : And<[
  TFL_OperandIsRankedAndHasDimPred<n, dim>,
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
      ".getShape()[" # dim # "] <= " # size>]>;

// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
      CPred<"$_op.getOperand(" # n #
        ").getType().cast<ShapedType>().getRank() >= " # m>]>>;

class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
  PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
    Or<[TFL_OperandIsUnrankedPred<x>,
        TFL_OperandIsUnrankedPred<y>,
        CPred<"!$_op.getOperand(" # y #
          ").getType().cast<ShapedType>().hasStaticShape()">,
        CPred<"$_op.getOperand(" # x #
          ").getType().cast<ShapedType>().getRank() == "
          "$_op.getOperand(" # y #
          ").getType().cast<ShapedType>().getShape()[0]">]>>;

class TFL_Operand0DOr1ElementTensor<int x> :
  PredOpTrait<"operand #" # x # " is a 0-d tensor or 1-d tensor w/ 1 element",
    Or<[TFL_OperandHasKnownRank<x, 0>,
        And<[TFL_OperandHasKnownRank<x, 1>, TFL_OperandDimEquals<x, 0, 1>]>]>>;

// Returns true if the i-th dim of the x-th operand is the same as the j-th
// dim of the y-th operand, or if either of those operands does not have a
// static shape.
class TFL_OperandsHaveSameDims<int x, int y, int i, int j> :
    Or<[TFL_OperandIsUnrankedPred<x>,
        TFL_OperandIsUnrankedPred<y>,
        CPred<"!$_op.getOperand(" # x #
          ").getType().cast<ShapedType>().hasStaticShape()">,
        CPred<"!$_op.getOperand(" # y #
          ").getType().cast<ShapedType>().hasStaticShape()">,
        CPred<"$_op.getOperand(" # x #
          ").getType().cast<ShapedType>().getShape()[" # i # "] == "
          "$_op.getOperand(" # y #
          ").getType().cast<ShapedType>().getShape()[" # j # "]">]>;

class TFL_OperandsHaveSameDimsTrait<int x, int y, int i, int j> :
  PredOpTrait<"dim " # i # " of operand " # x # " equals dim " # j #
    " of operand " # y,
    TFL_OperandsHaveSameDims<x, y, i, j>>;

// Returns true if the number of elements of the x-th operand is the same as
// the j-th dim of the y-th operand, or if either of those operands does not
// have a static shape.
class TFL_NumElementsEqualsDim<int x, int y, int j> :
  Or<[TFL_OperandIsUnrankedPred<x>,
      TFL_OperandIsUnrankedPred<y>,
      CPred<"!$_op.getOperand(" # x #
        ").getType().cast<ShapedType>().hasStaticShape()">,
      CPred<"!$_op.getOperand(" # y #
        ").getType().cast<ShapedType>().hasStaticShape()">,
      CPred<"$_op.getOperand(" # x #
        ").getType().cast<ShapedType>().getNumElements() == "
        "$_op.getOperand(" # y #
        ").getType().cast<ShapedType>().getShape()[" # j # "]">]>;

class TFL_NumElementsEqualsDimTrait<int x, int y, int j> :
  PredOpTrait<"operand " # x # " has number of elements equal to dim " # j #
    " of operand " # y,
    TFL_NumElementsEqualsDim<x, y, j>>;

// Returns true if the number of elements of the x-th operand equals n.
class TFL_NumElements<int x, int n> :
  Or<[TFL_OperandIsUnrankedPred<x>,
      CPred<"!$_op.getOperand(" # x #
        ").getType().cast<ShapedType>().hasStaticShape()">,
      CPred<"$_op.getOperand(" # x #
        ").getType().cast<ShapedType>().getNumElements() == " # n>]>;

class TFL_NumElementsTrait<int x, int n> :
  PredOpTrait<"operand " # x # " has number of elements equal to " # n,
    TFL_NumElements<x, n>>;

// tf.uint8 and tf.quint8 are mapped to the same tflite types, so they are equal
// when used as element types.
class TFL_TFTypesWithSameBits<int i, int j, int num> :
  And<[
    Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
        CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>,
    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
        CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
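
// Illustrative only: with num = 8, `TFL_TFTypesWithSameBits<0, 1, 8>` accepts
// a result 0 of element type ui8 against an operand 1 of element type
// !tf.quint8 (or vice versa), since both map to the same 8-bit TFLite type.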

class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
  And<[
    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
        CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>,
    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
        CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;

class TFL_OperandIsNoneOrHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[
      TFL_OperandIsNoneType<n>,
      TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ").getType().cast<ShapedType>().getRank() == " # m>]>>;

class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
  PredOpTrait<"operand " # n # " is at most " # m # "-D",
    Or<[
      TFL_OperandIsNoneType<n>,
      TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

class TFL_OperandHasRankAtMost<int n, int m> :
  PredOpTrait<"operand " # n # " is at most " # m # "-D",
    Or<[TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

class TFL_OperandHasRankAtLeast<int n, int m> :
  PredOpTrait<"operand " # n # " is at least " # m # "-D",
    Or<[TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ").getType().cast<ShapedType>().getRank() >= " # m>]>>;

class TFL_OperandHasRankRange<int n, int x, int y> :
  PredOpTrait<"operand " # n # " has rank range [" # x # ", " # y # "]",
    Or<[TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() "
      ">= " # x # " && $_op.getOperand(" # n # ").getType().cast<ShapedType>()."
      "getRank() <= " # y>]>>;

def TFL_FloatNonNegative : AttrConstraint<
    CPred<"$_self.isa<FloatAttr>() && "
            "!$_self.cast<FloatAttr>().getValue().isNegative()">,
    "whose value is non-negative">;

def TFL_BoolTrue : AttrConstraint<
    CPred<"$_self.isa<BoolAttr>() && $_self.cast<BoolAttr>().getValue()">,
    "whose value is true">;

def TFL_BoolFalse : AttrConstraint<
    CPred<"$_self.isa<BoolAttr>() && !$_self.cast<BoolAttr>().getValue()">,
    "whose value is false">;

class TFL_StringEqualsTo<string value> : AttrConstraint<
    CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">,
    "whose value equals '" # value # "'">;

// Ensures the array attribute's size is within the given maximum size.
class TFL_ArrayMaxCount<int n> : AttrConstraint<
    CPred<"$_self.isa<ArrayAttr>() && $_self.cast<ArrayAttr>().size() <= " # n>,
    "whose size is at most " # n>;

// Ensures the given integer attribute has the given value.
class TFL_IntEqualsTo<int n> : AttrConstraint<
    CPred<"$_self.isa<IntegerAttr>() && "
            "$_self.cast<IntegerAttr>().getInt() == " # n>,
    "whose value is " # n>;

// This is a quantization-aware version of TCresVTEtIsSameAsOp.
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
  TCOpResIsShapedTypePred<i, j>,
  Or<[
    TCresVTEtIsSameAsOpBase<i, j>,
    TFL_TFTypesWithSameBits<i, j, 8>,
    And<[
      SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
        quant_QuantizedType.predicate>,
      CPred<"quant::QuantizedType::castToStorageType("
                "getElementTypeOrSelf($_op.getResult(" # i # "))) == "
            "quant::QuantizedType::castToStorageType("
                "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;

def TFL_SameFirstOperandAndFirstResultElementType :
  PredOpTrait<"values and output must have same element type",
              TFL_TCresVTEtIsSameAsOp<0, 0>>;

// This is a quantization-aware version of TCopVTEtAreSameAt.
class TFL_TCopVTEtAreSameAt<int i, int j, int num=8> : Or<[
  TCopVTEtAreSameAt<[i, j]>,
  TFL_TFOperandTypesWithSameBits<i, j, num>,
  And<[
    SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
      quant_QuantizedType.predicate>,
    CPred<"quant::QuantizedType::castToStorageType("
              "getElementTypeOrSelf($_op.getOperand(" # i # "))) == "
          "quant::QuantizedType::castToStorageType("
              "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>;

//===----------------------------------------------------------------------===//
// TFL op common constraints.
//===----------------------------------------------------------------------===//

class OperandsSameElementTypeConstraintBase<string op> :
  PredOpTrait<op # " operands have same element type",
    Or<[
      TCopVTEtIsSameAs<0, 1>,
      // Two operands' values are both quantized and their types have the same
      // underlying storage type.
      And<[
        SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(0))",
          quant_QuantizedType.predicate>,
        CPred<"quant::QuantizedType::castToStorageType("
                  "getElementTypeOrSelf($_op.getOperand(0))) == "
              "quant::QuantizedType::castToStorageType("
                  "getElementTypeOrSelf($_op.getOperand(1)))">]>]>>;

// This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
// A binary op's lhs and rhs should have the same value type; this constraint
// can also compare quantized types.
def BinaryOpSameElementTypeConstraint :
  OperandsSameElementTypeConstraintBase<"binary op">;

// This is a constraint for most of the comparison ops, e.g., equal, not_equal,
// greater, greater_equal, less, etc. A comparison op's lhs and rhs should have
// the same value type; this constraint can also compare quantized types.
def ComparisonOpSameElementTypeConstraint :
  OperandsSameElementTypeConstraintBase<"comparison op">;

//===----------------------------------------------------------------------===//
// TFL common builders.
//===----------------------------------------------------------------------===//

def TFL_BroadcastableBinaryBuilder :
  OpBuilder<(ins "Value":$lhs, "Value":$rhs),
  [{
    auto resultType =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
    if (!resultType)
      mlir::emitError($_state.location, "non-broadcastable operands");
    $_state.addOperands({lhs, rhs});
    $_state.types.push_back(resultType);
  }]>;

def TFL_FusedBroadcastableBinaryBuilder :
  OpBuilder<(ins "Value":$lhs, "Value":$rhs,
    "StringAttr":$fusedActivationFunction),
  [{
    buildFusedBroadcastableBinOp(
       &$_builder, $_state, lhs, rhs, fusedActivationFunction);
  }]>;

def TFL_ComparisonBinaryBuilder :
  OpBuilder<(ins "Value":$lhs, "Value":$rhs),
  [{
    buildComparisonBinOp(&$_builder, $_state, lhs, rhs);
  }]>;
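
// Illustrative only (hypothetical C++ caller): an op that lists the fused
// builder above in its `builders` (as tfl.add does below) can be created from
// just its operands plus the fused activation; the broadcasted result type is
// computed inside the builder body, e.g.
//
//   Value sum = builder.create<TFL::AddOp>(
//       loc, lhs, rhs, builder.getStringAttr("NONE"));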

//===----------------------------------------------------------------------===//
// TFL op base class.
//===----------------------------------------------------------------------===//

class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TFL_Dialect, mnemonic, !listconcat(traits,
      [DeclareOpInterfaceMethods<TFL_RuntimeVerification>])> {
  // FlatBuffer generation specific information.
  // -------------------------------------------
  // When generating the FlatBuffer output some operations have
  // Options (as defined in the schema). These options are effectively
  // the attributes of the operations (e.g., what padding is to be used
  // for a pooling operator). Not all operations have Options and some
  // operations share Options. The following attributes indicate whether
  // the operation has Options in the serialized FlatBuffer.

  // Whether the TFLite operator has options in the schema representation.
  bit hasOptions = 0b0;

  // Used to specify a custom options type for TFLite operators where
  // the option's name does not match the TFLite operator's name.
  // If no customOption is specified then <name>Options is used if the op
  // hasOptions.
  string customOption = ?;
}

class TFL_ConvOp<string mnemonic, string opSummary, int index,
                 list<OpTrait> additional_traits = []> :
    TFL_Op<mnemonic,[NoSideEffect,
                     AccumulatorUniformScale<2, 0, 1>,
                     AffineQuantizedOpInterface,
                     AffineOpCoefficient<index, 1>,
                     TFL_SparseOp] # additional_traits> {
  let summary = opSummary # " operator";

  let description = [{
    Performs convolution operation on inputs.

    Inputs:
      `inputs[0]`: required: the input activation tensor
      `inputs[1]`: required: the filter weight tensor
      `inputs[2]`: optional: the bias tensor
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
    TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
    TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w
  );

  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);

  let hasOptions = 0b1;
}


//===----------------------------------------------------------------------===//
// TFL op definitions.
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    SameOperandsAndResultsScale]> {
  let summary = "Absolute value operator";

  let description = [{
Given a tensor `x`, this operation returns a tensor containing the absolute
value of each element in `x`. For example, if x is an input element and y is
an output element, this operation computes \\(y = |x|\\).
  }];

  let arguments = (ins TFL_TensorOf<[I16, F32, QI8, QI16]>:$x);

  let results = (outs TFL_TensorOf<[I16, F32, QI8, QI16]>:$y);

  let hasFolder = 1;
}
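
// Illustrative only: in the generic MLIR form this op round-trips as, e.g.
//
//   %y = "tfl.abs"(%x) : (tensor<4xf32>) -> tensor<4xf32>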

def TFL_AddOp : TFL_Op<"add", [
    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
      CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
    ResultsBroadcastableShape,
    NoSideEffect,
    Commutative,
    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let summary = "Addition operator";

  let description = [{
    Element-wise addition operation.
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
    TFL_AFAttr:$fused_activation_function);

  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);

  let hasFolder = 1;

  let builders = [TFL_FusedBroadcastableBinaryBuilder];

  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];

  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];

  let hasOptions = 1;
}
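
// Illustrative only: with the one-result parser/printer above, tfl.add uses
// the short assembly form, e.g.
//
//   %sum = tfl.add %lhs, %rhs {fused_activation_function = "NONE"} : tensor<2x2xf32>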

def TFL_AddNOp : TFL_Op<"add_n", [
    Commutative,
    NoSideEffect,
    SameOperandsAndResultsScale,
    NoQuantizableResult,
    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let summary = "add_n operator";

  let description = [{
    Adds all input tensors element-wise.
  }];

  let arguments = (ins
    TFL_VariadicTensorOf<[F32, I32]>:$inputs
  );

  let results = (outs
    TFL_TensorOf<[F32, I32]>:$sum
  );
}

def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = [{
Computes the "logical or" of elements across dimensions of a tensor.
  }];

  let description = [{
Reduces `input` along the dimensions given in `axis`. Unless
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
`axis`. If `keep_dims` is true, the reduced dimensions are
retained with length 1.
  }];

  let arguments = (ins
    TFL_BoolTensor:$input,
    TFL_I32Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    TFL_BoolTensor:$output
  );

  let hasOptions = 1;
  let customOption = "ReducerOptions";
}
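
// Illustrative only: reducing a 2x3 boolean tensor along axis 0 without
// keep_dims yields a rank-1 result, e.g.
//
//   %any = "tfl.reduce_any"(%input, %axes) {keep_dims = false}
//        : (tensor<2x3xi1>, tensor<1xi32>) -> tensor<3xi1>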

def TFL_ReduceAllOp : TFL_Op<"reduce_all", [
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = [{
Computes the "logical and" of elements across dimensions of a tensor.
  }];

  let description = [{
Reduces `input` along the dimensions given in `axis`. Unless
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
`axis`. If `keep_dims` is true, the reduced dimensions are
retained with length 1.
  }];

  let arguments = (ins
    TFL_BoolTensor:$input,
    TFL_I32Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    TFL_BoolTensor:$output
  );

  let hasOptions = 1;
  let customOption = "ReducerOptions";
}

def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
    NoSideEffect,
    TFL_OperandHasRank<0, 1>,
    TFL_OperandHasRank<1, 4>,
    TFL_OperandHasRank<2, 4>,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 2>>,
    AccumulatorUniformScale<3, 1, 2>,
    AffineQuantizedOpInterface, AffineOpCoefficient<0, 1>,
    TFL_SparseOp,
    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let summary = "Transpose convolution operator";

  let description = [{
    Performs transpose convolution operation on input.
  }];

  let arguments = (ins
    TFL_I32Tensor:$output_shape,
    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights,
    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
    TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias,
    TFL_PaddingAttr:$padding,
    Confined<I32Attr, [IntPositive]>:$stride_h,
    Confined<I32Attr, [IntPositive]>:$stride_w
  );

  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);

  let hasOptions = 1;

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // AffineQuantizedOpInterface:
    int GetChannelDimIndex() { return 0; }
    int GetQuantizationDimIndex() { return 0; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
  }];
}

def TFL_AveragePool2DOp:
    TFL_Op<"average_pool_2d",
           [NoSideEffect,
            SameOperandsAndResultsScale,
            DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let summary = "Average_pool_2d operator";

  let description = [{
    Performs average-pooling operation on input.
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
    I32Attr:$filter_height,
    I32Attr:$filter_width,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w,
    TFL_AFAttr:$fused_activation_function
  );

  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);

  let hasOptions = 1;
  let customOption = "Pool2DOptions";
}

def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
  let summary = "ArgMax operator";

  let description = [{
    Returns the index with the largest value across dimensions of a tensor.
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, I32, I8, UI8, QI8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$dim
  );

  let results = (outs
    TFL_I32OrI64Tensor:$output
  );

  let hasOptions = 1;

  DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
    return getResult().getType().cast<TensorType>().getElementType().
        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
            tflite::TensorType_INT32;
    }], [{
      TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
    }]>;
}

def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
  let summary = "ArgMin operator";

  let description = [{
    Returns the index with the smallest value across dimensions of a tensor.
      a = [1, 10, 26.9, 2.8, 166.32, 62.3]
      b = tf.math.argmin(input = a)
      c = tf.keras.backend.eval(b)
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, I32, I8, UI8, QI8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$dim
  );

  let results = (outs
    TFL_I32OrI64Tensor:$output
  );

  let hasOptions = 1;

  DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
    return getResult().getType().cast<TensorType>().getElementType().
        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
            tflite::TensorType_INT32;
    }], [{
      TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
    }]>;
}

def TFL_CeilOp: TFL_Op<"ceil", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult]> {
  let summary = "Ceil operator";

  let description = [{
    Returns element-wise ceil value of the input.
  }];

  let arguments = (ins TFL_FpTensor:$x);

  let results = (outs TFL_FpTensor:$y);
}

def TFL_ConcatenationOp : TFL_Op<"concatenation",
  [
    NoSideEffect,
    TFL_SameFirstOperandAndFirstResultElementType,
    SameOperandsAndResultsScale
  ]> {
  let summary = "Concatenation operator";

  let description = [{
    Concatenates tensors along one dimension
  }];

  let arguments = (
    ins TFL_VariadicTensorOf<
      [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$values,
    I32Attr:$axis,
    TFL_AFAttr:$fused_activation_function
  );

  let results = (outs
    TFL_TensorOf<
      [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$output
  );

  let hasOptions = 1;

  let hasFolder = 1;

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // SameScalesOpInterface:
    bool RequiredSameOperandsAndResultsScale(bool sign, int bit_width) {
      // uint8 doesn't require same operands and results scales.
      bool is_uint8 = !sign && (bit_width == 8);
      return !is_uint8;
    }
  }];
}
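
// Illustrative only: concatenating along axis 0 sums the axis-0 dims, e.g.
//
//   %out = "tfl.concatenation"(%a, %b) {axis = 0 : i32,
//       fused_activation_function = "NONE"}
//       : (tensor<1x4xf32>, tensor<2x4xf32>) -> tensor<3x4xf32>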

def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [ConstantLike, NoSideEffect,
    FirstAttrDerivedResultType]> {
  let summary = "Constant pseudo op.";

  let description = [{
    Represents a constant value in TensorFlow Lite dialect. This is not an
    actual operation and it will be lowered to a buffer instead.

    The op is allowed to have all the same type of attributes as tf.Const does
    (e.g., opaque TF attributes are allowed).
  }];

  let arguments = (ins ElementsAttr:$value);

  let results = (outs AnyTensor:$output);

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let builders = [
    OpBuilder<(ins "Attribute":$value),
    [{
      $_state.addAttribute("value", value);
      $_state.addTypes(value.getType());
    }]>
  ];
}
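
// Illustrative only: the result type is derived from the value attribute, e.g.
//
//   %cst = "tfl.pseudo_const"() {value = dense<1.0> : tensor<2xf32>}
//        : () -> tensor<2xf32>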

def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [
    NoSideEffect,
    FirstAttrDerivedResultType]> {
  let summary = "Sparse constant pseudo op.";

  let description = [{
    Represents a sparse constant value in TensorFlow Lite dialect. This is not
    an actual operation and it will be lowered to a buffer instead.
  }];

  let arguments = (ins ElementsAttr:$value,
                   SparsityParameterAttr:$s_param,
                   ElementsAttr:$compressed_data);

  let results = (outs AnyTensor:$output);

  let builders = [
    OpBuilder<(ins "Attribute":$value, "SparsityParameterAttr":$s_param,
      "Attribute":$compressed_data),
    [{
      $_state.addTypes(value.getType());
      $_state.addAttribute("value", value);
      $_state.addAttribute("s_param", s_param);
      $_state.addAttribute("compressed_data", compressed_data);
    }]>
  ];
}

def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
  let summary = "External const op.";

  let description = [{
    External const op holds a `buffer_index` which points to a constant
    in the flatbuffer.
  }];

  let arguments = (ins I32Attr:$buffer_index);

  let results = (outs AnyTensor:$output);
}

def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0,
      [DeclareOpInterfaceMethods<InferTypeOpInterface>,
       DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // AffineQuantizedOpInterface:
    int GetChannelDimIndex() { return 0; }
    int GetQuantizationDimIndex() { return 0; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }

    // Returns whether the return types are compatible.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
  }];
}

def TFL_CosOp: TFL_Op<"cos", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult]> {
  let summary = "Cosine operator";

  let description = [{
    Computes the element-wise cosine of the input.
  }];

  let arguments = (ins TFL_FpTensor:$x);

  let results = (outs TFL_FpTensor:$y);

  let hasFolder = 1;
}

def TFL_CumsumOp: TFL_Op<"cumsum", [
    NoSideEffect,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    NoQuantizableResult,
    TFL_OperandHasRank<1, 0>]> {
  let summary = "Cumsum operator";

  let description = [{
    Computes the cumulative sum of the tensor `x` along `axis`.
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, I32, I64]>:$input,
    TFL_I32Tensor:$axis,
    DefaultValuedAttr<BoolAttr, "false">:$exclusive,
    DefaultValuedAttr<BoolAttr, "false">:$reverse
  );

  let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output);

  let hasOptions = 1;
}
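
// Illustrative only: for input [1, 2, 3] along axis 0, cumsum yields
// [1, 3, 6]; with exclusive = true it yields [0, 1, 3], and with
// reverse = true it yields [6, 5, 3].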

def TFL_DepthwiseConv2DOp :
    TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3,
                [DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let arguments = (
    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
    TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
    TFL_1DTensorOfOrNone<[F32, I32, I64]>:$bias,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w,
    I32Attr:$depth_multiplier
  );

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // AffineQuantizedOpInterface:
    int GetChannelDimIndex() { return 3; }
    int GetQuantizationDimIndex() { return 3; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
  }];
}

def TFL_FCWO_Default  : StrEnumAttrCase<"DEFAULT">;
def TFL_FCWO_Shuffled4x16i8  : StrEnumAttrCase<"SHUFFLED4x16INT8">;

def TFL_FullyConnectedOptionsWeightFormatAttr :
    StrEnumAttr<"FullyConnectedOptionsWeightsFormat",
                "fully connected options weights format", [
      TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8
    ]>;

// TODO(jpienaar): Update post discussion on semantics of FC OP.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
    AffineQuantizedOpInterface,
    AffineOpCoefficient<-1, 1>,
    TFL_SparseOp,
    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let summary = "Fully connected op";

  let arguments = (ins
    TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter,
    TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,

    TFL_AFAttr:$fused_activation_function,
    TFL_FullyConnectedOptionsWeightFormatAttr:$weights_format,
    BoolAttr:$keep_num_dims
  );

  // Depending on the weights format, this op can have one or two outputs.
  let results = (outs
    TFL_VariadicTensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$output
  );

  let verifier = [{ return Verify(*this); }];

  let hasOptions = 1;

  let hasCanonicalizer = 1;

  let hasFolder = 1;

  let extraClassDeclaration = [{
    // AffineQuantizedOpInterface:
    int GetChannelDimIndex() { return 0; }
    int GetQuantizationDimIndex() { return -1; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {{1, 4}}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {{1, 16}}; }
  }];
}

def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
   NoSideEffect,
   TFL_OperandHasAtleastRank<0, 2>,
   TFL_OperandHasAtleastRank<1, 2>,
   PredOpTrait<"x and output must have same element type or they are int8 and int32",
       Or<[TFL_TCresVTEtIsSameAsOp<0, 0>,
           And<[CPred<"getElementTypeOrSelf($_op.getOperand(0)).isInteger(8)">,
                CPred<"getElementTypeOrSelf($_op.getOperand(1)).isInteger(8)">,
                CPred<"getElementTypeOrSelf($_op.getResult(0)).isInteger(32)">]>]>>]> {

  let summary = "Batch Matrix Multiply Operator";

  let description = [{
Performs a batched matrix multiplication on the inputs. Follows the
conventions of TensorFlow BatchMatMulV2, with support for unknown dimensions
in the batch dimensions and broadcasting.

    Inputs:
      `inputs[0]`: required: input LHS
      `inputs[1]`: required: input RHS
      `adjoint_lhs`: optional: Transpose LHS (default false)
      `adjoint_rhs`: optional: Transpose RHS (default false)
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, QI8, QI16, I8]>:$x,
    TFL_TensorOf<[F32, QI8, QI16, I8]>:$y,
    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
    DefaultValuedAttr<BoolAttr, "false">:$adj_y
  );

  let results = (outs
    TFL_TensorOf<[F32, QI8, QI16, I32]>:$output
  );

  let hasOptions = 1;
}
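
// Illustrative only: batch dimensions broadcast, so a [2, 3, 4] LHS times a
// [4, 5] RHS yields a [2, 3, 5] result, e.g.
//
//   %out = "tfl.batch_matmul"(%lhs, %rhs) {adj_x = false, adj_y = false}
//        : (tensor<2x3x4xf32>, tensor<4x5xf32>) -> tensor<2x3x5xf32>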

def TFL_GatherOp : TFL_Op<"gather", [
    NoSideEffect,
    SameOperandsAndResultsScale,
    TFL_OperandHasAtleastRank<0, 1>,
    PredOpTrait<"params and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>
  ]> {
  let summary = "Gather operator";

  let description = [{
    Gather slices from `params` axis `axis` according to `indices`.
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
    TFL_TensorOf<[I32, I64]>:$indices,
    I32Attr:$axis,
    DefaultValuedAttr<I32Attr, "0">:$batch_dims
  );

  let builders =
  [
    OpBuilder<(ins "Value":$params, "Value":$indices, "IntegerAttr":$axis, "IntegerAttr":$batch_dims),
    [{ BuildGatherOp(&$_builder, $_state, params, indices, axis, batch_dims); }]>
  ];

  let results = (outs
    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output
  );

  let hasOptions = 1;
}
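
// Illustrative only: gathering two rows (axis 0) of a [3, 4] params tensor
// yields a [2, 4] result, e.g.
//
//   %out = "tfl.gather"(%params, %indices) {axis = 0 : i32, batch_dims = 0 : i32}
//        : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>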

def TFL_GatherNdOp : TFL_Op<"gather_nd", [
    NoSideEffect,
    SameOperandsAndResultsScale,
    PredOpTrait<"params and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
  let summary = "Gather_nd operator";

  let description = [{
    Gather slices from `params` into a Tensor with shape specified by `indices`.
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$params,
    TFL_I32OrI64Tensor:$indices
  );

  let results = (outs
    TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$output
  );
}

def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
    NoSideEffect,
    SameOperandsAndResultsScale,
    TFL_OperandHasAtleastRank<0, 1>,
    TFL_OperandHasAtleastRank<1, 1>,
    PredOpTrait<"updates and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 1>>
  ]> {
  let summary = "Scatter_nd operator";

  let description = [{
    Scatter `updates` into a new tensor according to `indices`
  }];

  let arguments = (ins
    TFL_TensorOf<[I32]>:$indices,
    TFL_TensorOf<[F32, I8, I64, I32, UI8]>:$updates,
    TFL_1DTensorOf<[I32]>:$shape
  );

  let results = (outs
    TFL_TensorOf<[F32, I8, I64, I32, UI8]>:$output
  );

  let verifier = [{ return Verify(*this); }];

  let hasOptions = 1;
}
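
// Illustrative only: scattering updates [9.0, 10.0] at indices [[0], [2]]
// into shape [4] yields [9.0, 0.0, 10.0, 0.0], e.g.
//
//   %out = "tfl.scatter_nd"(%indices, %updates, %shape)
//        : (tensor<2x1xi32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32>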

// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
    ResultsBroadcastableShape,
    ComparisonOpSameElementTypeConstraint,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect]> {
  let summary = "Less_equal operator";

  let description = [{
    Element-wise less_equal operation.
  }];

  let arguments = (
      ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs,
      TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs);

  let results = (outs TFL_BoolTensor:$output);

  let builders = [TFL_ComparisonBinaryBuilder];

  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];

  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];

  let hasOptions = 0;
}

def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [
    TFL_OperandHasRank<0, 4>,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = "Local Response Normalization.";

  let description = [{
The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last
dimension), and each vector is normalized independently.  Within a given vector,
each component is divided by the weighted, squared sum of inputs within
`depth_radius`.  In detail,

    sqr_sum[a, b, c, d] =
        sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
    output = input / (bias + alpha * sqr_sum) ** beta

For details, see [Krizhevsky et al., ImageNet classification with deep
convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
  }];

  let arguments = (ins
      TFL_FpTensor:$input,
      I32Attr:$radius,
      F32Attr:$bias,
      F32Attr:$alpha,
      F32Attr:$beta
  );

  let results = (outs
    TFL_FpTensor:$output
  );

  let hasOptions = 1;
}

def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    ResultsBroadcastableShape,
    ComparisonOpSameElementTypeConstraint,
    NoSideEffect]> {
  let summary = "Greater_equal operator";

  let description = [{
    Element-wise greater_equal operation.
  }];

  let arguments = (
      ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8]>:$lhs,
      TFL_TensorOf<[F32, I32, I64, QUI8, QI8]>:$rhs);

  let results = (outs TFL_BoolTensor:$output);

  let builders = [TFL_ComparisonBinaryBuilder];

  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];

  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];

  let hasOptions = 0;
}

def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
  NoSideEffect,
  TFL_OperandHasAtleastRank<0, 1>,
  PredOpTrait<"operand and result must have the same element type",
    TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
  let summary = [{
    Returns a tensor with the provided diagonal and everything else padded with zeros.
  }];

  let description = [{
    Given a diagonal, returns a tensor with the diagonal and everything else padded with zeros.
    Assume diagonal has k dimensions `[I, J, K, ..., N]`, then the output is a tensor of rank `k+1`
    with dimensions `[I, J, K, ..., N, N]` where:
       `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n].`
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QUI8, QI8, TFL_Quint8]>:$diagonal
  );

  let results = (outs
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QUI8, QI8, TFL_Quint8]>:$output
  );

  let hasOptions = 0;
}

def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [
    TFL_OperandHasAtleastRank<0, 2>,
    PredOpTrait<"input and result must have the same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    NoSideEffect]> {
  let summary = [{
    Returns a batched matrix tensor with new batched diagonal values.
  }];

  let description = [{
Given `input` and `diagonal`, this operation returns a tensor with the
same shape and values as `input`, except for the main diagonal of the
innermost matrices.  These will be overwritten by the values in `diagonal`.
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
  );

  let results = (outs
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
  );

  let hasOptions = 0;
}

// These ops are named NonMaxSuppressionV4 & NonMaxSuppressionV5 to be
// consistent with TensorFlow's naming. They are NOT 'versions' of NMS in the
// sense that one is an incremental change over the other.
// In reality NonMaxSuppressionV5 implements Soft Non Max Suppression and
// NonMaxSuppressionV4 performs hard NMS.

def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [
  NoSideEffect,
  // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
  TFL_OperandHasRank<0, 2>,
  PredOpTrait<"boxes should have dim[1] == 4",
      TFL_OperandDimEquals<0, 1, 4>>,
  // Operand 1 (scores) should be a 1-dim tensor
  TFL_OperandHasRank<1, 1>,
  // Other operands are scalar params.
  TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
  TFL_OperandHasRank<4, 0>,
  NoQuantizableResult]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes.  Bounding boxes with score less than
`score_threshold` are removed.  Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute.  Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes.  The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation.  For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
  }];

  let arguments = (ins
    TFL_FpTensor:$boxes,
    TFL_FpTensor:$scores,
    TFL_I32Tensor:$max_output_size,
    TFL_FpTensor:$iou_threshold,
    TFL_FpTensor:$score_threshold
  );

  let results = (outs
    TFL_I32Tensor:$selected_indices,
    TFL_I32Tensor:$valid_outputs
  );
}

def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [
  NoSideEffect,
  // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
  TFL_OperandHasRank<0, 2>,
  PredOpTrait<"boxes should have dim[1] == 4",
      TFL_OperandDimEquals<0, 1, 4>>,
  // Operand 1 (scores) should be a 1-dim tensor
  TFL_OperandHasRank<1, 1>,
  // Other operands are scalar params.
  TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
  TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>,
  NoQuantizableResult]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes.  Bounding boxes with score less than
`score_threshold` are removed.  Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute.  Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes.  The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation.  For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (cf.
Bodla et al., https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
  }];

  let arguments = (ins
    TFL_FpTensor:$boxes,
    TFL_FpTensor:$scores,
    TFL_I32Tensor:$max_output_size,
    TFL_FpTensor:$iou_threshold,
    TFL_FpTensor:$score_threshold,
    TFL_FpTensor:$soft_nms_sigma
  );

  let results = (outs
    TFL_I32Tensor:$selected_indices,
    TFL_FpTensor:$selected_scores,
    TFL_I32Tensor:$valid_outputs
  );
}

def TFL_NotEqualOp : TFL_Op<"not_equal", [
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    ComparisonOpSameElementTypeConstraint,
    ResultsBroadcastableShape,
    Commutative,
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = "Not_equal operator";

  let description = [{
    Element-wise not_equal operation.
  }];

  let arguments = (
      ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs,
      TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs);

  let results = (outs TFL_BoolTensor:$output);

  let builders =
  [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs),
    [{
        buildComparisonBinOp(&$_builder, $_state, lhs, rhs);
      }]>
  ];

  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];

  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}

def TFL_DivOp : TFL_Op<"div", [
    // TODO(fengliuai): NoQuantizableResult is only correct for int8
    // quantization. Update to handle Uint8 quantization.
    BinaryOpSameElementTypeConstraint,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult,
    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
  let summary = "Division operator";

  let description = [{
    Element-wise division operation.
  }];

  let arguments = (
      ins TFL_TensorOf<[F32, I32, QUI8]>:$lhs,
      TFL_TensorOf<[F32, I32, QUI8]>:$rhs,
      TFL_AFAttr:$fused_activation_function);

  let results = (outs TFL_TensorOf<[F32, I32, QUI8]>:$output);

  let builders = [TFL_FusedBroadcastableBinaryBuilder];

  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];

  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];

  let hasOptions = 1;

  let hasFolder = 1;
}
1506
1507def TFL_EluOp: TFL_Op<"elu", [
1508    NoSideEffect,
1509    SameOperandsAndResultShape,
1510    TFL_SameFirstOperandAndFirstResultElementType]> {
1511  let summary = "Exponential Linear Unit operator";
  let description = [{
    Computes the Exponential Linear Unit (ELU) function
      f(x) -> exp(x) - 1 for x < 0, x for x >= 0
    element-wise.
  }];
1517
1518  let arguments = (ins TFL_TensorOf<[F32, I8]>:$x);
1519
1520  let results = (outs TFL_TensorOf<[F32, I8]>:$y);
1521
1522  let hasOptions = 0;
1523}
1524
1525def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
1526    [NoSideEffect,
1527     PredOpTrait<"value and output must have same element type",
1528       TFL_TCresVTEtIsSameAsOp<0, 1>>,
1529     TFL_OperandHasRank<0, 1>,
1530     TFL_OperandHasRankAtLeast<1, 2>
1531    ]> {
1532  let summary = "Embedding lookup operator";
1533
1534  let description = [{
1535    Looks up ids in a list of embedding tensors.
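
    For example (an illustrative sketch in the same pseudo-notation as the
    other examples in this file; shapes and values are made up):

    ```
    # 'value' is an embedding table of shape [4, 2]:
    #   [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
    # 'lookup' is [2, 0]
    embedding_lookup(lookup, value) ==> [[2.0, 2.1], [0.0, 0.1]]
    ```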
1536  }];
1537
1538  let arguments = (ins
1539    TFL_TensorOf<[I32]>:$lookup,
1540    TFL_TensorOf<[F32, I8, UI8]>:$value
1541   );
1542
1543  let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);
1544}
1545
1546def TFL_EqualOp: TFL_Op<"equal", [
1547    Commutative,
1548    NoSideEffect,
1549    ResultsBroadcastableShape,
1550    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
1551    ComparisonOpSameElementTypeConstraint]> {
1552  let summary = "Equal operator";
1553
1554  let description = [{
    Returns the truth value of x == y element-wise.
1556  }];
1557
1558  let arguments = (
1559    ins
1560    TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$x,
1561    TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$y
1562  );
1563
1564  let results = (outs TFL_BoolTensor:$output);
1565
1566  let builders = [TFL_ComparisonBinaryBuilder];
1567}
1568
1569def TFL_ExpOp: TFL_Op<"exp", [
1570    NoSideEffect,
1571    SameOperandsAndResultType,
1572    NoQuantizableResult]> {
1573  let summary = "Natural exponentiation operator";
1574
1575  let description = [{
1576    Performs element-wise natural exponentiation operation on input.
1577  }];
1578
1579  let arguments = (ins TFL_FpTensor:$x);
1580
1581  let results = (outs TFL_FpTensor:$y);
1582
1583  let hasOptions = 0b1;
1584}
1585
1586def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
1587    NoSideEffect,
1588    SameOperandsAndResultsScale,
1589    PredOpTrait<"input and output must have same element type",
1590      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
1591  let summary = "Inserts a dimension of 1 into a tensor's shape.";
1592
1593  let description = [{
1594Given a tensor `input`, this operation inserts a dimension of 1 at the
1595dimension index `axis` of `input`'s shape. The dimension index `axis` starts at
1596zero; if you specify a negative number for `axis` it is counted backward from
1597the end.
1598
1599This operation is useful if you want to add a batch dimension to a single
1600element. For example, if you have a single image of shape `[height, width,
1601channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
1602which will make the shape `[1, height, width, channels]`.
1603
1604Other examples:
1605
1606```
1607# 't' is a tensor of shape [2]
1608shape(expand_dims(t, 0)) ==> [1, 2]
1609shape(expand_dims(t, 1)) ==> [2, 1]
1610shape(expand_dims(t, -1)) ==> [2, 1]
1611
1612# 't2' is a tensor of shape [2, 3, 5]
1613shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
1614shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
1615shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
1616```
1617
1618This operation requires that:
1619
1620`-1-input.dims() <= dim <= input.dims()`
1621
1622This operation is related to `squeeze()`, which removes dimensions of
1623size 1.
1624  }];
1625
  // TODO: Restrictions on dim's size and valid range are not modeled here.
1627  let arguments = (ins AnyTensor:$input, TFL_I32OrI64Tensor:$dim);
1628
1629  let results = (outs AnyTensor:$output);
1630
1631  let hasOptions = 1;
1632}
1633
1634def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect,
1635                                      SameOperandsAndResultsScale]> {
1636  let summary = "Removes dimensions of size 1 from the shape of a tensor.";
1637
1638  let description = [{
1639Given a tensor `input`, this operation returns a tensor of the same type with
1640all dimensions of size 1 removed. If you don't want to remove all size 1
1641dimensions, you can remove specific size 1 dimensions by specifying
1642`squeeze_dims`.
1643
1644For example:
1645
1646```
1647# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
1648shape(squeeze(t)) ==> [2, 3]
1649```
1650
1651Or, to remove specific size 1 dimensions:
1652
1653```
1654# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
1655shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
1656```
1657  }];
1658
1659  let arguments = (ins
1660    AnyTensor:$input,
1661    Confined<DefaultValuedAttr<I64ArrayAttr, "{}">, [TFL_ArrayMaxCount<8>]>:$squeeze_dims
1662  );
1663
1664  let results = (outs
1665    AnyTensor:$output
1666  );
1667
1668  let hasFolder = 1;
1669  let hasOptions = 1;
1670
1671  let customOption = "SqueezeOptions";
1672}
1673
1674def TFL_FillOp: TFL_Op<"fill", [
1675    NoSideEffect,
1676    PredOpTrait<"input and result must have same element type",
1677      TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
  let summary = "Fills the tensor with a given value.";
1679  let description = [{
    Fills a tensor of shape `dims` with the given scalar `input` value.
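
    For example (illustrative values):

    ```
    # 'dims' is [2, 3], 'input' is the scalar 9
    fill(dims, input) ==> [[9, 9, 9], [9, 9, 9]]
    ```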
1681  }];
1682
1683  let arguments = (ins TFL_I32OrI64Tensor:$dims,
1684                   TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$input);
1685
1686  let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$result);
1687
1688  let hasOptions = 0;
1689}
1690
1691def TFL_FloorOp: TFL_Op<"floor", [
1692    NoSideEffect,
1693    SameOperandsAndResultShape,
1694    SameOperandsAndResultType,
1695    NoQuantizableResult]> {
1696  let summary = "Floor operator";
1697
1698  let description = [{
1699    Returns element-wise floor value of the input.
1700  }];
1701
1702  let arguments = (ins TFL_FpTensor:$x);
1703
1704  let results = (outs TFL_FpTensor:$y);
1705}
1706
1707def TFL_FloorDivOp : TFL_Op<"floor_div", [
1708    ResultsBroadcastableShape,
1709    NoSideEffect,
1710    BinaryOpSameElementTypeConstraint,
1711    PredOpTrait<"lhs and output must have same element type",
1712      TFL_TCresVTEtIsSameAsOp<0, 0>>,
1713    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
1714    NoQuantizableResult]> {
1715  let summary = "Floor div operator";
1716
1717  let description = [{
1718    Element-wise floor div operation.
1719  }];
1720
1721  let arguments = (
1722    ins TFL_TensorOf<[F32, I32]>:$lhs, TFL_TensorOf<[F32, I32]>:$rhs);
1723
1724  let results = (outs TFL_TensorOf<[F32, I32]>:$output);
1725
1726  let builders = [TFL_BroadcastableBinaryBuilder];
1727
1728  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
1729
1730  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
1731}
1732
1733def TFL_FloorModOp : TFL_Op<"floor_mod", [
1734    ResultsBroadcastableShape,
1735    NoSideEffect,
1736    BinaryOpSameElementTypeConstraint,
1737    PredOpTrait<"lhs and output must have same element type",
1738      TFL_TCresVTEtIsSameAsOp<0, 0>>,
1739    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
1740    NoQuantizableResult]> {
  let summary = "Division remainder";
1742
1743  let description = [{
    Element-wise division remainder (floored modulus) operation.
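
    For example (an illustrative sketch; assuming the usual floored-division
    semantics, the result takes the sign of the divisor):

    ```
    floor_mod(7, 3)  ==> 1
    floor_mod(7, -3) ==> -2
    ```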
1745  }];
1746
1747  let arguments = (
1748    ins TFL_TensorOf<[I32, I64, F32]>:$lhs,
1749    TFL_TensorOf<[I32, I64, F32]>:$rhs);
1750
1751  let results = (outs TFL_TensorOf<[I32, I64, F32]>:$output);
1752
1753  let builders = [TFL_BroadcastableBinaryBuilder];
1754}
1755
1756def TFL_GreaterOp : TFL_Op<"greater", [
1757    ResultsBroadcastableShape,
1758    ComparisonOpSameElementTypeConstraint,
1759    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
1760    NoSideEffect]> {
1761  let summary = "Greater operator";
1762
1763  let description = [{
1764    Element-wise greater operation.
1765  }];
1766
1767  let arguments = (
1768    ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
1769    TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);
1770
1771  let results = (outs TFL_BoolTensor:$output);
1772
1773  let builders = [TFL_ComparisonBinaryBuilder];
1774
1775  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
1776
1777  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
1778}
1779
1780def TFL_HardSwishOp: TFL_Op<"hard_swish", [
1781    NoSideEffect,
1782    SameOperandsAndResultShape,
1783    PredOpTrait<"input and output must have same element type",
1784      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
1785  let summary = "Hardswish activation function.";
1786  let description = [{
    Computes the hard-swish activation function
1788      f(x) -> (x * relu6(x+3))/6
1789    element-wise.
1790  }];
1791
1792  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input);
1793
1794  let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output);
1795
1796  let hasOptions = 0;
1797}
1798
1799def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
1800    FixedOutputRangeInterface,
1801    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
1802  let summary = "L2 Normalize Operator";
1803
1804  let description = [{
    L2Normalization Op: normalizes the input so that it has unit L2 norm.
1806  }];
1807
1808  let arguments = (ins
1809    TFL_TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input,
1810    TFL_AFAttr:$fused_activation_function
1811  );
1812
1813  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output);
1814
1815  let hasOptions = 1;
1816
1817  let customOption = "L2NormOptions";
1818
1819  let extraClassDeclaration = [{
1820  // FixedOutputRangeInterface:
1821  quant::UniformQuantizedType GetFixedOutputRange(
1822      bool is_signed, int bit_width) {
1823    auto result_type = output().getType();
1824    // central_value = min_value / 2 + (max_value - 1) / 2 + 1
1825    // zero_point = central_value
1826    // scale = 1. / (central_value - min_value)
1827    return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
1828        /*scale=*/1.0 / 128, /*zero_point=*/0);
1829  }
1830  }];
1831}
1832
1833def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
1834    SameOperandsAndResultShape,
1835    NoSideEffect,
1836    PredOpTrait<"input and output must have same element type",
1837      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
1838  let summary = "Leaky Relu operator";
1839
1840  let description = [{
1841    Element-wise Leaky ReLU operator
1842      x -> x >= 0 ? x : (alpha * x)
1843  }];
1844
1845  let arguments = (
1846    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input,
1847    // Slope of the activation function at x < 0.
1848    F32Attr:$alpha
1849  );
1850
1851  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output);
1852
1853  let hasOptions = 0b1;
1854}
1855
1856def TFL_LessOp : TFL_Op<"less", [
1857    ResultsBroadcastableShape,
1858    ComparisonOpSameElementTypeConstraint,
1859    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
1860    NoSideEffect]> {
1861  let summary = "Less operator";
1862
1863  let description = [{
1864    Element-wise less operation.
1865  }];
1866
1867  let arguments = (
1868    ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
1869    TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);
1870
1871  let results = (outs TFL_BoolTensor:$output);
1872
1873  let builders = [TFL_ComparisonBinaryBuilder];
1874
1875  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
1876
1877  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
1878}
1879
1880def TFL_LogicalAndOp : TFL_Op<"logical_and", [
1881    NoSideEffect,
1882    NoQuantizableResult]> {
1883  let summary = "Logical AND operator";
1884
1885  let description = [{
1886    Element-wise logical AND operation.
1887  }];
1888
1889  let arguments = (
1890    ins TFL_BoolTensor:$lhs,
1891    TFL_BoolTensor:$rhs);
1892
1893  let results = (outs TFL_BoolTensor:$output);
1894
1895  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
1896
1897  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
1898}
1899
1900def TFL_LogicalNotOp : TFL_Op<"logical_not", [
1901    NoSideEffect,
1902    SameOperandsAndResultShape,
1903    NoQuantizableResult]> {
1904  let summary = "Logical NOT operator";
1905
1906  let description = [{
1907    Element-wise logical NOT operation.
1908  }];
1909
1910  let arguments = (ins TFL_BoolTensor:$lhs);
1911
1912  let results = (outs TFL_BoolTensor:$output);
1913}
1914
1915def TFL_LogicalOrOp : TFL_Op<"logical_or", [
1916    NoSideEffect,
1917    NoQuantizableResult]> {
1918  let summary = "Logical OR operator";
1919
1920  let description = [{
1921    Element-wise logical OR operation.
1922  }];
1923
1924  let arguments = (
1925    ins TFL_BoolTensor:$lhs,
1926    TFL_BoolTensor:$rhs);
1927
1928  let results = (outs TFL_BoolTensor:$output);
1929
1930  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
1931
1932  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
1933}
1934
1935def TFL_LogisticOp: TFL_Op<"logistic", [
1936    NoSideEffect,
1937    PredOpTrait<"x and y must have same element type",
1938      TFL_TCresVTEtIsSameAsOp<0, 0>>,
1939    SameOperandsAndResultShape,
1940    FixedOutputRangeInterface,
1941    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
1942  let summary = "Logistic operator";
1943
1944  let description = [{
1945    Computes element-wise Sigmoid of input
1946  }];
1947
1948  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);
1949
1950  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
1951
1952  let extraClassDeclaration = [{
1953  // FixedOutputRangeInterface:
1954  quant::UniformQuantizedType GetFixedOutputRange(
1955      bool is_signed, int bit_width) {
1956    auto result_type = y().getType();
1957    // zero_point = 0
1958    // scale = 1. / (max_value + 1)
1959    return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
1960        /*scale=*/1.0 / 256, /*zero_point=*/-128);
1961  }
1962  }];
1963
1964  // This builder doesn't work with quantized type, so it can only be used by
1965  // non-quantization tablegen patterns. Currently, it is used by the
1966  // elementwise-move reordering pattern in the optimize_patterns.td
1967  let builders = [
1968    OpBuilder<(ins "Value":$input),
1969    [{
1970      $_state.addOperands({input});
1971      $_state.addTypes(input.getType());
1972    }]>
1973  ];
1974}
1975
1976def TFL_LogOp: TFL_Op<"log", [
1977    NoSideEffect,
1978    SameOperandsAndResultShape,
1979    SameOperandsAndResultType,
1980    NoQuantizableResult]> {
1981  let summary = "Natural logarithm operator";
1982
1983  let description = [{
1984    Performs element-wise natural logarithm operation on input.
1985  }];
1986
1987  let arguments = (ins TFL_FpTensor:$x);
1988
1989  let results = (outs TFL_FpTensor:$y);
1990
1991  let hasFolder = 1;
1992}
1993
1994def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
1995    NoSideEffect,
1996    SameOperandsAndResultShape,
1997    PredOpTrait<"x and y must have same element type",
1998      TFL_TCresVTEtIsSameAsOp<0, 0>>,
1999    FixedOutputRangeInterface,
2000    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
2001  let summary = "Log softmax operator";
2002
2003  let description = [{
2004    Computes element-wise log softmax activations with the following formula
2005
2006      input - log(reduce_sum(exp(input), dim))
2007  }];
2008
2009  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input);
2010
2011  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
2012
2013  let hasOptions = 1;
2014
2015  let extraClassDeclaration = [{
2016  // FixedOutputRangeInterface:
2017  quant::UniformQuantizedType GetFixedOutputRange(
2018      bool is_signed, int bit_width) {
2019    auto result_type = output().getType();
2020    // zero_point = max_value
2021    // scale = -log_softmax_output_min / (max_value + 1)
2022    return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
2023        /*scale=*/16.0 / 256, /*zero_point=*/127);
2024  }
2025  }];
2026}
2027
2028// TODO(ashwinm): Revisit the granularity of the PredOpTraits. We could
2029// break this into smaller PredOpTraits, each with more descriptive messages
2030// that would make it easier to trace failures OR, need a way to specify desc
2031// per Predicate inside the trait and get tablegen to use that to emit error
2032// message.
2033def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and "
2034    "result types match specified constraints",
2035  And<[
2036    // The input and output tensors should have the same elemental type
2037    // and they should be one of the specified types below.
2038    TCopVTEtIs<0, AnyTypeOf<[F32, QI8, QUI8]>>,
2039    TFL_TCresVTEtIsSameAsOp<0, 0>]>>;
2040
2041def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
2042    TFL_OperandHasRank<0, 4>,
2043    PredOpTrait<"input and output must have same element type",
2044      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2045    NoSideEffect,
2046    MaxPoolOperandAndResultConstraints,
2047    SameOperandsAndResultsScale,
2048    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
2049  let summary = "Max Pool 2D op";
2050
2051  let description = [{
2052    Performs max pool 2D on input.
2053
2054    Inputs:
2055      `inputs[0]`: required: the input tensor
2056  }];
2057
2058  let arguments = (
2059    ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input,
2060    TFL_PaddingAttr:$padding,
2061    I32Attr:$stride_w,
2062    I32Attr:$stride_h,
2063    I32Attr:$filter_width,
2064    I32Attr:$filter_height,
2065    TFL_AFAttr:$fused_activation_function
2066  );
2067
2068  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output);
2069
2070  let hasOptions = 1;
2071
2072  let customOption = "Pool2DOptions";
2073}
2074
2075def TFL_MaximumOp : TFL_Op<"maximum", [
2076    ResultsBroadcastableShape,
2077    NoSideEffect,
2078    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
2079    Commutative,
2080    SameOperandsAndResultsScale]> {
2081  let summary = "Max operator";
2082  let description = [{
2083    Element-wise max operation.
2084  }];
2085
2086  let arguments = (
2087    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
2088    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
2089  );
2090
2091  let results = (outs
2092    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
2093  );
2094
2095  let builders = [TFL_BroadcastableBinaryBuilder];
2096
2097  let hasOptions = 0;
2098}
2099
2100def TFL_MeanOp : TFL_Op<"mean", [
2101    PredOpTrait<"input and output must have same element type",
2102      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2103    NoSideEffect]> {
2104  let summary = "Mean operator";
2105
2106  let description = [{
2107    Computes the mean of elements across dimensions of a tensor.
    Reduces `input` along the dimensions given in `axis`.
    Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for
    each entry in `axis`. If `keep_dims` is true, the reduced dimensions are
    retained with length 1.
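
    For example (illustrative values):

    ```
    # 'input' is [[1.0, 2.0], [3.0, 4.0]]
    mean(input, axis=[1], keep_dims=false) ==> [1.5, 3.5]
    ```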
2112  }];
2113
2114  let arguments = (ins
2115    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input,
2116    TFL_TensorOf<[I32, I64]>:$axis,
2117    BoolAttr:$keep_dims
2118  );
2119
2120  let results = (outs
2121    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output);
2122
2123  let hasOptions = 1;
2124  let customOption = "ReducerOptions";
2125}
2126
2127def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
2128  let summary = "OneHot operator";
2129
2130  let description = [{
    Returns a one-hot tensor. The locations represented by indices in
    `indices` take value `on_value`, while all other locations take value
    `off_value`.

    If the input `indices` is rank `N`, the output will have rank `N+1`.
    The new axis is created at dimension `axis` (default: the new axis is
    appended at the end).
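
    For example (illustrative values):

    ```
    # 'indices' is [0, 2], 'depth' is 3
    # 'on_value' is 5, 'off_value' is 0, 'axis' is -1
    one_hot(indices, depth, on_value, off_value) ==> [[5, 0, 0],
                                                      [0, 0, 5]]
    ```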
2137  }];
2138
2139  let arguments = (ins
2140    TFL_TensorOf<[I32, I64]>:$indices,
2141    TFL_I32Tensor:$depth,
2142    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value,
2143    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value,
2144
2145    I32Attr:$axis
2146  );
2147
2148  let results = (outs
2149    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output
2150  );
2151
2152  let hasOptions = 1;
2153}
2154
2155def TFL_RoundOp: TFL_Op<"round", [
2156    NoSideEffect,
2157    SameOperandsAndResultShape,
2158    SameOperandsAndResultType,
2159    NoQuantizableResult]> {
2160  let summary = "Round operator";
2161
2162  let description = [{
2163Rounds the values of a tensor to the nearest integer, element-wise.
2164  }];
2165
2166  let arguments = (ins
2167    TFL_FpTensor:$x
2168  );
2169
2170  let results = (outs
2171    TFL_FpTensor:$y
2172  );
2173}
2174
2175def TFL_SliceOp : TFL_Op<"slice", [
2176    PredOpTrait<"input and output must have same element type",
2177      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2178    NoSideEffect,
2179    SameOperandsAndResultsScale,
2180    TFL_OperandHasRankAtMost<0, 5>,
2181    TFL_OperandHasRankAtMost<1, 1>,
2182    TFL_OperandHasRankAtMost<2, 1>]> {
2183  let summary = "Return a slice from 'input'.";
2184
2185  let description = [{
2186The output tensor is a tensor with dimensions described by 'size'
2187whose values are extracted from 'input' starting at the offsets in
2188'begin'.
2189
2190`begin` is zero-based; `size` is one-based. If size[i] is -1, all remaining
2191elements in dimension i are included in the slice. In other words, this is
2192equivalent to setting:
2193  size[i] = input.dim_size(i) - begin[i]
2194
2195*Requirements*:
2196  0 <= begin[i] <= begin[i] + size[i] <= Di  for i in [0, n)
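
For example (illustrative values):

```
# 'input' is [[1, 2, 3], [4, 5, 6]]
slice(input, begin=[0, 1], size=[2, 2]) ==> [[2, 3], [5, 6]]
slice(input, begin=[1, 0], size=[1, -1]) ==> [[4, 5, 6]]
```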
2197  }];
2198
2199  let arguments = (ins
2200    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2201    TFL_I32OrI64Tensor:$begin,
2202    TFL_I32OrI64Tensor:$size
2203  );
2204
2205  let results = (outs
2206    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
2207  );
2208
2209  let verifier = [{ return Verify(*this); }];
2210
2211  let hasCanonicalizer = 1;
2212}
2213
2214def TFL_SumOp: TFL_Op<"sum", [
2215    PredOpTrait<"input and output must have same element type",
2216      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2217    NoSideEffect]> {
2218
2219  let summary = "Sum operator";
2220
2221  let description = [{
2222    Computes the sum reduction along the specified axes
2223  }];
2224
2225  let arguments = (ins
2226    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2227    TFL_I32Tensor:$axes,
2228    BoolAttr:$keep_dims
2229  );
2230
2231  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
2232
2233  let hasOptions = 1;
2234  let customOption = "ReducerOptions";
2235}
2236
2237def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
2238    PredOpTrait<"input and output must have same element type",
2239      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2240    NoSideEffect,
2241    SameOperandsAndResultsScale]> {
2242  let summary = "Min-reduction operator";
2243
2244  let description = [{
2245    Computes the min reduction along the specified axes
2246  }];
2247
2248  let arguments = (ins
2249    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2250    TFL_I32Tensor:$axes,
2251    BoolAttr:$keep_dims
2252  );
2253
2254  let results = (outs
2255    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
2256
2257  let hasOptions = 1;
2258  let customOption = "ReducerOptions";
2259}
2260
2261def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
2262    PredOpTrait<"input and output must have same element type",
2263      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2264    NoSideEffect,
2265    SameOperandsAndResultsScale]> {
2266  let summary = "Max-reduction operator";
2267
2268  let description = [{
2269    Computes the max reduction along the specified axes
2270  }];
2271
2272  let arguments = (ins
2273    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2274    TFL_I32Tensor:$axes,
2275    BoolAttr:$keep_dims
2276  );
2277
2278  let results = (outs
2279    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
2280
2281  let hasOptions = 1;
2282  let customOption = "ReducerOptions";
2283}
2284
2285def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
2286    PredOpTrait<"input and output must have same element type",
2287      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2288    NoSideEffect]> {
2289  let summary = "Prod-reduction operator";
2290
2291  let description = [{
2292    Computes the product along the specified axes
2293  }];
2294
2295  let arguments = (ins
2296    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2297    TFL_I32Tensor:$axes,
2298    BoolAttr:$keep_dims
2299  );
2300
2301  let results = (outs
2302    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
2303
2304  let hasOptions = 1;
2305  let customOption = "ReducerOptions";
2306}
2307
2308def TFL_MinimumOp : TFL_Op<"minimum", [
2309    ResultsBroadcastableShape,
2310    NoSideEffect,
2311    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
2312    Commutative,
2313    SameOperandsAndResultsScale]> {
2314  let summary = "Min operator";
2315  let description = [{
2316    Element-wise min operation.
2317  }];
2318
2319  let arguments = (
2320    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
2321    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
2322  );
2323
2324  let results = (outs
2325    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
2326  );
2327
2328  let builders = [TFL_BroadcastableBinaryBuilder];
2329
2330  let hasOptions = 0;
2331}
2332
2333def TFL_MulOp : TFL_Op<"mul", [
2334    ResultsBroadcastableShape,
2335    NoSideEffect,
2336    Commutative,
2337    BinaryOpSameElementTypeConstraint,
2338    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
2339      CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>,
2340    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
2341  let summary = "Multiplication operator";
2342
2343  let description = [{
2344    Element-wise multiplication operation.
2345  }];
2346
2347  let arguments = (
2348    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
2349    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
2350    TFL_AFAttr:$fused_activation_function);
2351
2352  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
2353
2354  let hasFolder = 1;
2355
2356  let builders = [TFL_FusedBroadcastableBinaryBuilder];
2357
2358  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
2359
2360  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
2361
2362  let hasOptions = 1;
2363}
2364
2365def TFL_NegOp: TFL_Op<"neg", [
2366    NoSideEffect,
2367    SameOperandsAndResultShape,
2368    SameOperandsAndResultType,
2369    NoQuantizableResult]> {
2370  let summary = "Negation operator";
2371
2372  let description = [{
2373    Computes element-wise negation of input
2374  }];
2375
2376  let arguments = (ins TFL_TensorOf<[F32, I32, I64]>:$x);
2377
2378  let results = (outs TFL_TensorOf<[F32, I32, I64]>:$y);
2379
2380  let hasOptions = 0b1;
2381
2382  let hasFolder = 1;
2383}
2384
2385def TFL_PackOp : TFL_Op<"pack", [
2386    TFL_SameFirstOperandAndFirstResultElementType,
2387    NoSideEffect,
2388    SameOperandsAndResultsScale]> {
2389  let summary = "Packs a list of tensors along a dimension into one tensor";
2390
2391  let description = [{
2392    Packs a list of `values_count` rank-`R` tensors into one rank-`(R+1)`
2393    tensor.
2394
2395    Packs the `values_count` tensors in `values` into a tensor with rank one
2396    higher than each tensor in `values`, by packing them along the `axis`
2397    dimension.
2398
2399    Given a list of tensors of shape `(A, B, C)`;
2400
2401    if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
2402    if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
2403    Etc.
2404
2405    For example:
2406
2407    ```
2408    # 'x' is [1, 4]
2409    # 'y' is [2, 5]
2410    # 'z' is [3, 6]
2411    pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]]  # Pack along first dim.
2412    pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
2413    ```
2414
2415    This is the opposite of `unpack`.
2416  }];
2417
2418  let arguments = (ins
2419    TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values,
2420
2421    Confined<I32Attr, [IntPositive]>:$values_count,
2422    I32Attr:$axis
2423  );
2424
2425  let results = (outs
2426    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output
2427  );
2428
2429  let verifier = [{ return Verify(*this); }];
2430
2431  let hasCanonicalizer = 1;
2432
2433  let hasOptions = 1;
2434}
2435
2436def TFL_PadOp : TFL_Op<"pad", [
2437    PredOpTrait<"input and output must have same element type",
2438      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2439    NoSideEffect,
2440    SameOperandsAndResultsScale,
2441    TFL_OperandHasRankAtMost<0, 5>,
2442    TFL_OperandHasRank<1, 2>,
2443    TFL_OperandRankEquals1DimOfOperand<0, 1>,
2444    PredOpTrait<"the first dim size of the padding argument must be at most 5",
2445      Or<[TFL_OperandIsUnrankedPred<1>,
2446          TFL_OperandDimIsAtMost<1, 0, 5>]>>]> {
2447  let summary = "Padding operator";
2448
2449  let description = [{
    This operation pads `input` with zeros according to the `paddings` you
2451    specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
2452    the rank of `input`. For each dimension D of `input`, `paddings[D, 0]`
2453    indicates how many zeros to add before the contents of `input` in that
2454    dimension, and `paddings[D, 1]` indicates how many zeros to add after the
2455    contents of `input` in that dimension.
2456
2457    The padded size of each dimension D of the output is:
2458
2459      `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
2460
2461    For example:
2462
2463    ```
2464    # 't' is [[1, 1], [2, 2]]
2465    # 'paddings' is [[1, 1], [2, 2]]
2466    # rank of 't' is 2
2467    pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
2468                          [0, 0, 1, 1, 0, 0]
2469                          [0, 0, 2, 2, 0, 0]
2470                          [0, 0, 0, 0, 0, 0]]
2471    ```
2472  }];
2473
2474  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2475    TFL_I32OrI64Tensor:$padding);
2476
2477  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
2478
2479  let hasOptions = 1;
2480
2481  let hasFolder = 1;
2482}
2483
2484def TFL_PadV2Op : TFL_Op<"padv2", [
2485    PredOpTrait<"input and output must have same element type",
2486      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2487    NoSideEffect,
2488    SameOperandsAndResultsScale,
2489    TFL_OperandHasRankAtMost<0, 5>,
2490    TFL_OperandHasRank<1, 2>,
2491    TFL_OperandHasRank<2, 0>,
2492    TFL_OperandRankEquals1DimOfOperand<0, 1>,
2493    PredOpTrait<"the first dim size of the padding argument must be at most 5",
2494      Or<[TFL_OperandIsUnrankedPred<1>,
2495          TFL_OperandDimIsAtMost<1, 0, 5>]>>,
2496    PredOpTrait<"input and constant value operands must have same element type",
2497      TFL_TCopVTEtAreSameAt<0, 2>>]> {
2498  let summary = "Padding operator v2";
2499
2500  let description = [{
    This operation pads `input` according to the `paddings` and
    `constant_values` you specify. `paddings` is an integer tensor with shape
    `[Dn, 2]`, where n is the rank of `input`. For each dimension D of `input`,
    `paddings[D, 0]` indicates how many padding values to add before the
    contents of `input` in that dimension, and `paddings[D, 1]` indicates how
    many padding values to add after the contents of `input` in that dimension.
    `constant_values` is a scalar tensor of the same type as `input` that
    indicates the value to use for padding `input`.
2509
2510    The padded size of each dimension D of the output is:
2511
2512      `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
2513
2514    For example:
2515
2516    ```
    # 't' is [[1, 1], [2, 2]]
    # 'paddings' is [[1, 1], [2, 2]]
    # 'constant_values' is -1
    # rank of 't' is 2
    pad(t, paddings, constant_values) ==> [[-1, -1, -1, -1, -1, -1]
                                           [-1, -1,  1,  1, -1, -1]
                                           [-1, -1,  2,  2, -1, -1]
                                           [-1, -1, -1, -1, -1, -1]]
2524    ```
2525  }];
2526
2527  let arguments = (
2528    ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input,
2529    TFL_I32OrI64Tensor:$padding,
2530    TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values);
2531
2532  let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$output);
2533
2534  let hasOptions = 1;
2535
2536  let hasFolder = 1;
2537}
2538
2539def TFL_PowOp : TFL_Op<"pow", [
2540    ResultsBroadcastableShape,
2541    NoSideEffect,
2542    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
2543    NoQuantizableResult]> {
2544  let summary = "Power operator";
2545
2546  let description = [{
2547    Element-wise power operation.
2548  }];
2549
2550  let arguments = (
2551    ins TFL_TensorOf<[F32, I32]>:$lhs,
2552    TFL_TensorOf<[F32, I32]>:$rhs);
2553
2554  let results = (outs TFL_TensorOf<[F32, I32]>:$output);
2555
2556  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
2557
2558  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
2559
2560  let builders = [TFL_BroadcastableBinaryBuilder];
2561}
2562
2563def TFL_PReluOp : TFL_Op<"prelu", [
2564    NoSideEffect,
2565    ResultsBroadcastableShape,
2566    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
2567    BinaryOpSameElementTypeConstraint,
2568    PredOpTrait<"input and output must have the same element type",
2569      TFL_TCresVTEtIsSameAsOp<0, 0>>, AffineQuantizedOpInterface]> {
2570  let summary = "Parameterized Relu operator";
2571
2572  let description = [{
2573    Parameterized Relu operator
2574      x -> x >= 0 ? x : (alpha * x)
2575    where alpha is a trainable tensor.
    alpha should have the same shape as input or be broadcastable to it.
2577  }];
2578
2579  let arguments = (
2580    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
2581    TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha
2582  );
2583
2584  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
2585
2586  let verifier = [{ return Verify(*this); }];
2587
2588  let extraClassDeclaration = [{
2589    // AffineQuantizedOpInterface:
2590    int GetChannelDimIndex() { return 0; }
2591    int GetQuantizationDimIndex() { return -1; }
2592  }];
2593}
2594
2595def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
2596  let summary = "Rank operator.";
2597  let description = [{
2598    Returns the rank of a tensor.
2599  }];
2600
2601  let arguments = (ins AnyTensor:$input);
2602
2603  let results = (outs TFL_IntTensor:$output);
2604
2605  let hasFolder = 1;
2606}
2607
2608def TFL_ReluOp: TFL_Op<"relu", [
2609    PredOpTrait<"x and y must have same element type",
2610      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2611    NoSideEffect,
2612    SameOperandsAndResultShape]> {
2613  let summary = "Relu operator";
2614
2615  let description = [{
2616    Element-wise Relu operator
2617      x -> max(0, x)
2618  }];
2619
2620  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x);
2621
2622  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y);
2623
2624  // This builder doesn't work with quantized type, so it can only be used by
2625  // non-quantization tablegen patterns. Currently, it is used by the
2626  // elementwise-move reordering pattern in the optimize_patterns.td
2627  let builders = [
2628    OpBuilder<(ins "Value":$input),
2629    [{
2630      $_state.addOperands({input});
2631      $_state.addTypes(input.getType());
2632    }]>
2633  ];
2634}
2635
2636def TFL_Relu6Op: TFL_Op<"relu6", [
2637    PredOpTrait<"x and y must have same element type",
2638      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2639    NoSideEffect,
2640    SameOperandsAndResultShape]> {
2641  let summary = "Relu6 operator";
2642
2643  let description = [{
2644    Element-wise Relu6 operator
2645      x -> max(0, min(6, x))
2646  }];
2647
2648  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
2649
2650  let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
2651
2652  // This builder doesn't work with quantized type, so it can only be used by
2653  // non-quantization tablegen patterns. Currently, it is used by the
2654  // elementwise-move reordering pattern in the optimize_patterns.td
2655  let builders = [
2656    OpBuilder<(ins "Value":$input),
2657    [{
2658      $_state.addOperands({input});
2659      $_state.addTypes(input.getType());
2660    }]>
2661  ];
2662}
2663
2664def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [
2665    PredOpTrait<"x and y must have same element type",
2666      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2667    NoSideEffect,
2668    SameOperandsAndResultShape]> {
2669  let summary = "Relu1 operator";
2670
2671  let description = [{
2672    Element-wise Relu1 operator
2673      x -> max(-1, min(1, x))
2674  }];
2675
2676  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
2677
2678  let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
2679
2680  // This builder doesn't work with quantized type, so it can only be used by
2681  // non-quantization tablegen patterns. Currently, it is used by the
2682  // elementwise-move reordering pattern in the optimize_patterns.td
2683  let builders = [
2684    OpBuilder<(ins "Value":$input),
2685    [{
2686      $_state.addOperands({input});
2687      $_state.addTypes(input.getType());
2688    }]>
2689  ];
2690}
2691
2692def TFL_ReshapeOp: TFL_Op<"reshape", [
2693    NoSideEffect, SameOperandsAndResultsScale]> {
2694  let summary = "Reshape operator";
2695
2696  let description = [{
2697    Produces a tensor with the same values but different static shape defined
2698    by the output type.
2699  }];
2700
2701  let arguments = (
2702    ins AnyTensor:$input,
2703    TFL_I32Tensor:$shape);
2704
2705  let results = (outs AnyTensor:$output);
2706  let hasCanonicalizer = 0b1;
2707  let hasFolder = 1;
2708
2709  let verifier = [{ return Verify(*this); }];
2710}
2711
2712def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [
2713    PredOpTrait<"input and output must have same element type",
2714      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2715    NoSideEffect,
2716    TFL_OperandHasRank<1, 1>]> {
2717  let summary = "Reverses variable length slices.";
2718
2719  let description = [{
2720This op first slices `input` along the dimension `batch_dim`, and for each
2721slice `i`, reverses the first `seq_lengths[i]` elements along
2722the dimension `seq_dim`.
2723
2724The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`,
2725and `seq_lengths` must be a vector of length `input.dims[batch_dim]`.
2726
2727The output slice `i` along dimension `batch_dim` is then given by input
2728slice `i`, with the first `seq_lengths[i]` slices along dimension
2729`seq_dim` reversed.
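
For example (illustrative values, with batch_dim = 0 and seq_dim = 1):

```
# 'input' is [[1, 2, 3, 4], [5, 6, 7, 8]]
# 'seq_lengths' is [2, 3]
reverse_sequence(input, seq_lengths) ==> [[2, 1, 3, 4], [7, 6, 5, 8]]
```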
2730  }];
2731
2732  let arguments = (ins
2733    TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$input,
2734    TFL_I32OrI64Tensor:$seq_lengths,
2735
2736    Confined<I32Attr, [IntNonNegative]>:$seq_dim,
2737    Confined<I32Attr, [IntNonNegative]>:$batch_dim
2738  );
2739
2740  let results = (outs
2741    TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$output
2742  );
2743
2744  let hasOptions = 1;
2745}
2746
2747def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
2748                                  TFL_SameFirstOperandAndFirstResultElementType,
2749                                  SameOperandsAndResultShape]> {
2750  let summary = "Reciprocal of square root operator";
2751
2752  let description = [{
    Computes the element-wise reciprocal of the square root of the input.
2754  }];
2755
2756  let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
2757
2758  let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
2759
2760  let hasFolder = 1;
2761}
2762
2763def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
2764  let summary = "Shape operator";
2765
2766  let description = [{
2767    Returns the shape of a tensor.
2768  }];
2769
2770  let arguments = (ins AnyTensor:$input);
2771
2772  let results = (outs TFL_TensorOf<[I32, I64]>:$output);
2773
2774  DerivedTypeAttr out_type = DerivedTypeAttr<[{
2775    return getResult().getType().cast<TensorType>().getElementType();
2776  }]>;
2777
2778  let hasOptions = 1;
2779
2780  let hasFolder = 1;
2781}
2782
2783def TFL_RangeOp: TFL_Op<"range", [
2784    NoSideEffect,
2785    TFL_OperandHasRank<0, 0>,
2786    TFL_OperandHasRank<1, 0>,
2787    TFL_OperandHasRank<2, 0>,
2788    PredOpTrait<"operands and output must have same element type",
2789      And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>,
2790           TCresVTEtIsSameAsOp<0, 2>]>>,
2791    NoQuantizableResult]> {
2792  let summary = "Range operator";
2793
2794  let description = [{
2795    Returns a 1D tensor defined by a sequence from `start` to `limit` with
2796    a given `delta`.
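
    For example (illustrative values):

    ```
    range(start=3, limit=18, delta=3) ==> [3, 6, 9, 12, 15]
    ```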
2797  }];
2798
2799  let arguments = (ins
2800    TFL_TensorOf<[I32, F32]>:$start,
2801    TFL_TensorOf<[I32, F32]>:$limit,
2802    TFL_TensorOf<[I32, F32]>:$delta);
2803
2804  let results = (outs TFL_TensorOf<[I32, F32]>:$result);
2805
2806  let hasFolder = 1;
2807}
2808
2809def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
2810    PredOpTrait<"input and output must have same element type",
2811      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2812    NoSideEffect,
2813    TFL_OperandHasRank<1, 1>]> {
2814  let summary = "ReverseV2 Operator";
2815
2816  let description = [{
    Reverses specific dimensions of a tensor.

    Given a tensor `input` and an int32 tensor `axis` representing the set of
    dimensions of `input` to reverse, this operation reverses each dimension
    i for which there exists j such that axis[j] == i.

    Args:
      tensor: A Tensor. Must be one of the following types:
        uint8, int8, int16, int32, int64, float32, bool. Up to 8-D.

      axis: A Tensor of type int32, with only 1 element, which is the axis
        index to reverse.
      TODO: Add support for multiple elements.
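
    For example (illustrative values):

    ```
    # 'input' is [[1, 2, 3], [4, 5, 6]]
    reverse_v2(input, axis=[0]) ==> [[4, 5, 6], [1, 2, 3]]
    reverse_v2(input, axis=[1]) ==> [[3, 2, 1], [6, 5, 4]]
    ```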
2831  }];
2832
2833  let arguments = (
2834    ins
2835    TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$input,
2836    TFL_I32Tensor:$axis
2837  );
2838
2839  let results = (outs
2840    TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$output);
2841}
2842
2843// Select has many instances in TF models where one or more of its operands
2844// are unranked. Therefore, we skip adding shape constraints here.
2845def TFL_SelectOp : TFL_Op<"select", [
2846  NoSideEffect,
2847  SameOperandsAndResultsScale,
2848  PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
2849  PredOpTrait<"operands and result have same element type",
2850    TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
2851  let summary = "Select operator";
2852
2853  let description = [{
    Selects values from 'x' if the corresponding value of 'condition' is true,
    or values from 'y' if false. There are two valid condition input shapes:

    1. the same shape as 'x' and 'y' (in which case the select is
       elementwise), or
    2. rank 1, matching the size of the first dimension of 'x' and 'y'.
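
    For example (illustrative values, elementwise case):

    ```
    # 'condition' is [true, false, true]
    # 'x' is [1, 2, 3], 'y' is [4, 5, 6]
    select(condition, x, y) ==> [1, 5, 3]
    ```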
2859  }];
2860
2861  let arguments = (ins
2862    TFL_BoolTensor:$condition,
2863    TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x,
2864    TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y);
2865
2866  let results = (outs
2867    TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output);
2868
2869  // TODO(jpienaar): autogenerate this.
2870  let builders = [
2871    OpBuilder<(ins "Value":$condition, "Value":$x, "Value":$y),
2872    [{
2873    auto resultType = x.getType();
2874    $_state.addOperands({condition, x, y});
2875    $_state.types.push_back(resultType);
2876  }]>];
2877
2878  let hasOptions = 1;
2879}
2880
2881def TFL_SelectV2Op : TFL_Op<"select_v2", [
2882    ResultsBroadcastableShape,
2883    NoSideEffect,
2884    SameOperandsAndResultsScale,
2885    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
2886    PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
2887    PredOpTrait<"operands and result have same element type",
2888      TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
2889  let summary = "SelectV2 operator";
2890
2891  let description = [{
    Selects values from 'x' if the corresponding value of 'condition' is true,
    or values from 'y' if false. There are two valid condition input shapes:

    1. the same shape as 'x' and 'y' (in which case the select is
       elementwise), or
    2. shapes broadcastable between 'condition', 'x' and 'y'.
2897  }];
2898
2899  let arguments = (ins
2900    TFL_BoolTensor:$condition,
2901    TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x,
2902    TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y);
2903
2904  let results = (outs
2905    TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output);
2906
2907  let builders = [
2908    OpBuilder<(ins "Value":$cond, "Value":$x, "Value":$y),
2909    [{
2910    BuildSelectV2Op(&$_builder, $_state, cond, x, y);
2911  }]>];
2912
2913  let hasOptions = 1;
2914}
2915
2916def TFL_SinOp: TFL_Op<"sin", [
2917    NoSideEffect,
2918    SameOperandsAndResultShape,
2919    SameOperandsAndResultType,
2920    NoQuantizableResult]> {
2921  let summary = "Sine operator";
2922
2923  let description = [{
2924    Computes element-wise Sine of input
2925  }];
2926
2927  let arguments = (ins TFL_FpTensor:$x);
2928
2929  let results = (outs TFL_FpTensor:$y);
2930
2931  let hasFolder = 1;
2932}
2933
2934def TFL_SoftmaxOp : TFL_Op<"softmax", [
2935    NoSideEffect,
2936    PredOpTrait<"input and output must have same element type",
2937      TFL_TCresVTEtIsSameAsOp<0, 0>>,
2938    TFL_OperandHasRankRange<0, 1, 4>,
2939    SameOperandsAndResultShape,
2940    FixedOutputRangeInterface,
2941    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
2942  let summary = "Softmax operator";
2943
2944  let description = [{
2945    Computes element-wise softmax activations with the following formula
2946
      exp(input * beta) / tf.reduce_sum(exp(input * beta), dim)
2948  }];
2949
2950  let arguments = (
2951    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input,
2952    F32Attr:$beta
2953  );
2954
2955  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output);
2956
2957  let hasOptions = 1;
2958
2959  let extraClassDeclaration = [{
2960  // FixedOutputRangeInterface:
2961  quant::UniformQuantizedType GetFixedOutputRange(
2962      bool is_signed, int bit_width) {
2963    auto result_type = output().getType();
2964    // zero_point = 0
2965    // scale = 1. / (max_value + 1)
2966    return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
2967        /*scale=*/1.0 / 256, /*zero_point=*/-128);
2968  }
2969  }];
2970}
2971
2972def TFL_SqrtOp: TFL_Op<"sqrt", [
2973    NoSideEffect,
2974    SameOperandsAndResultShape,
2975    SameOperandsAndResultType,
2976    NoQuantizableResult]> {
2977  let summary = "Square root operator";
2978
2979  let description = [{
2980    Computes element-wise Square root of input
2981  }];
2982
2983  let arguments = (ins TFL_FpTensor:$x);
2984
2985  let results = (outs TFL_FpTensor:$y);
2986
2987  let hasFolder = 1;
2988}
2989
2990def TFL_SquareOp: TFL_Op<"square", [
2991    NoSideEffect,
2992    SameOperandsAndResultShape,
2993    SameOperandsAndResultType,
2994    NoQuantizableResult]> {
2995  let summary = "Square operator";
2996
2997  let description = [{
2998    Computes element-wise Square of input
2999  }];
3000
3001  let arguments = (ins TFL_FpTensor:$x);
3002
3003  let results = (outs TFL_FpTensor:$y);
3004
3005  let hasOptions = 0b1;
3006
3007  let hasFolder = 1;
3008}
3009
3010def TFL_SubOp : TFL_Op<"sub", [
3011    ResultsBroadcastableShape,
3012    BinaryOpSameElementTypeConstraint,
3013    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
3014      CPred<"TFL::VerifySubOpShapeConstraints(llvm::cast<SubOp>($_op))">>,
3015    NoSideEffect,
3016    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
3017  let summary = "Subtraction operator";
3018
3019  let description = [{
3020    Element-wise subtraction operation.
3021  }];
3022
3023  let arguments = (
3024    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
3025    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
3026    TFL_AFAttr:$fused_activation_function);
3027
3028  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
3029
3030  let hasFolder = 1;
3031
3032  let builders = [TFL_FusedBroadcastableBinaryBuilder];
3033
3034  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
3035
3036  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
3037
3038  let hasOptions = 1;
3039}
3040
3041def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
3042    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
3043    BinaryOpSameElementTypeConstraint,
3044    TFL_SameFirstOperandAndFirstResultElementType,
3045    ResultsBroadcastableShape,
3046    NoSideEffect]> {
3047  let summary = "Squared difference operator";
3048
3049  let description = [{
3050    Element-wise squared difference operation.
3051  }];
3052
3053  let arguments = (
3054    ins TFL_TensorOf<[F32, I32, QI8]>:$lhs,
3055    TFL_TensorOf<[F32, I32, QI8]>:$rhs);
3056
3057  let results = (outs TFL_TensorOf<[F32, I32, QI8]>:$output);
3058
3059  let builders = [TFL_BroadcastableBinaryBuilder];
3060
3061  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
3062
3063  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
3064}
3065
3066def TFL_TanhOp: TFL_Op<"tanh", [
3067    NoSideEffect,
3068    SameOperandsAndResultShape,
3069    PredOpTrait<"input and output must have same element type",
3070      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3071    FixedOutputRangeInterface,
3072    DeclareOpInterfaceMethods<TFL_ArithmeticCount>]> {
3073  let summary = "Hyperbolic tangent operator";
3074
3075  let description = [{
3076    Computes element-wise Hyperbolic tangent of input
3077  }];
3078
3079  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input);
3080
3081  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$output);
3082
3083  // This builder doesn't work with quantized type, so it can only be used by
3084  // non-quantization tablegen patterns. Currently, it is used by the
3085  // elementwise-move reordering pattern in the optimize_patterns.td
3086  let builders = [
3087    OpBuilder<(ins "Value":$input),
3088    [{
3089      $_state.addOperands({input});
3090      $_state.addTypes(input.getType());
3091    }]>
3092  ];
3093
3094  let extraClassDeclaration = [{
3095  // FixedOutputRangeInterface:
3096  quant::UniformQuantizedType GetFixedOutputRange(
3097      bool is_signed, int bit_width) {
3098    auto result_type = output().getType();
3099    // central_value = min_value / 2 + (max_value - 1) / 2 + 1
3100    // zero_point = central_value
3101    // scale = 1. / (central_value - min_value)
3102    return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
3103        /*scale=*/1.0 / 128, /*zero_point=*/0);
3104  }
3105  }];
3106}
3107
3108def TFL_TileOp: TFL_Op<"tile", [
3109    NoSideEffect,
3110    SameOperandsAndResultsScale,
3111    PredOpTrait<"input and output must have same element type",
3112      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
3113  let summary = "Tile operator.";
3114  let description = [{
3115    Constructs a tensor by tiling a given tensor.
3116
    This operation creates a new tensor by replicating `input` `multiples`
    times. The output tensor's i'th dimension has input.dims(i) * multiples[i]
    elements, and the values of `input` are replicated multiples[i] times
    along the i'th dimension.
    For example, tiling [a b c d] by [2] produces [a b c d a b c d].
3122  }];
3123
3124  let arguments = (ins
3125    TFL_TensorOf<[F32, I1, I32, I64, UI8, QUI8, TFL_Str]>:$input,
3126    TFL_I32OrI64Tensor:$multiples);
3127
3128  let results = (outs
3129    TFL_TensorOf<[F32, I1, I32, I64, UI8, QUI8, TFL_Str]>:$output);
3130
3131  let hasOptions = 0;
3132}
3133
3134// TODO(jpienaar): Maybe make it accept any single element tensor as `k`.
3135// TODO(jpienaar): Check that input has one or more dimensions.
// TODO(jpienaar): Check that k is less than or equal to the internal dimension
3137def TFL_TopKV2Op: TFL_Op<"topk_v2", [
3138    NoSideEffect,
3139    TFL_OperandHasRankAtLeast<0, 1>,
3140    TFL_OperandHasRank<1, 0>,
3141    PredOpTrait<"result and input element type match",
3142      TFL_TCresVTEtIsSameAsOp<0,0>>,
3143    SameOperandsAndResultsScale]> {
3144  let summary = "TopK operator";
3145
3146  let description = [{
    Returns the top `k` largest elements along each last-dimensional slice of
    `input`, together with their indices within the last dimension of the
    input tensor.
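
    For example (illustrative values):

    ```
    # 'input' is [10, 40, 30], 'k' is 2
    topk_v2(input, k) ==> values: [40, 30], indices: [1, 2]
    ```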
3150  }];
3151
3152  let arguments = (ins
3153    TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input,
3154    TFL_I32Tensor:$k);
3155
3156  let results = (outs
3157    TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$values,
3158    TFL_I32Tensor:$indices);
3159
3160  let builders = [
3161    OpBuilder<(ins "Value":$input, "Value":$k),
3162    [{ BuildTopKOp(&$_builder, $_state, input, k); }]>];
3163
3164  let hasOptions = 1;
3165}
3166
3167def TFL_TransposeOp : TFL_Op<"transpose", [
3168    NoSideEffect,
3169    TFL_OperandHasRankAtMost<0, 5>,
3170    TFL_OperandHasRank<1, 1>,
3171    PredOpTrait<"input and output must have same element type",
3172      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3173    SameOperandsAndResultsScale]> {
3174  let summary = "Transpose operator";
3175
3176  let description = [{
    Returns the transpose of `input`, permuting the dimensions according to
    `perm`.
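
    For example (illustrative values):

    ```
    # 'input' is [[1, 2, 3], [4, 5, 6]], 'perm' is [1, 0]
    transpose(input, perm) ==> [[1, 4], [2, 5], [3, 6]]
    ```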
3178  }];
3179
3180  let arguments = (ins
3181    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input,
3182    TFL_TensorOf<[I32]>:$perm
3183  );
3184
3185  let results = (outs
3186    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output
3187  );
3188
3189  let verifier = [{ return Verify(*this); }];
3190
3191  let hasFolder = 1;
3192
3193  let builders = [
3194    OpBuilder<(ins "Value":$input, "Value":$perm),
3195    [{ BuildTransposeOp(&$_builder, $_state, input, perm); }]>
3196  ];
3197}
3198
3199def TFL_UnpackOp : TFL_Op<"unpack", [
3200    NoSideEffect,
3201    SameOperandsAndResultElementType,
3202    SameOperandsAndResultsScale,
3203    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
3204  let summary = "Unpacks a tensor along a dimension into multiple tensors";
3205
3206  let description = [{
3207    Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
3208
3209    Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
3210    For example, given a tensor of shape `(A, B, C, D)`;
3211
3212    If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]`
3213      and each tensor in `output` will have shape `(B, C, D)`. (Note that the
3214      dimension unpacked along is gone, unlike `split`).
3215
3216    If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]`
3217      and each tensor in `output` will have shape `(A, C, D)`.
3218    Etc.
3219
3220    This is the opposite of `pack`.
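
    For illustration, a hypothetical use in MLIR assembly, unpacking two
    rank-1 slices from a rank-2 tensor (shapes chosen only for the example):

    ```mlir
    %0:2 = "tfl.unpack"(%input) {axis = 0 : i32, num = 2 : i32} : (tensor<2x3xf32>) -> (tensor<3xf32>, tensor<3xf32>)
    ```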
3221  }];
3222
3223  let arguments = (ins
3224    TFL_TensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$input,
3225
3226    Confined<I32Attr, [IntNonNegative]>:$num,
3227    I32Attr:$axis
3228  );
3229
3230  let results = (outs
3231    TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs
3232  );
3233
3234  let extraClassDeclaration = [{
3235    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
3236  }];
3237
3238  let hasOptions = 1;
3239}
3240
3241def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [
3242    PredOpTrait<"input and output must have same element type",
3243      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3244    SameOperandsAndResultType,
3245    SameOperandsAndResultShape,
3246    NoSideEffect,
3247    NoQuantizableResult]> {
3248  let summary = "ZerosLike operator";
3249
3250  let description = [{
3251    Returns a tensor of zeros with the same shape and type as the input tensor.
3252  }];
3253
3254  let arguments = (ins TFL_TensorOf<[I64, I32, F32]>:$input);
3255
3256  let results = (outs TFL_TensorOf<[I64, I32, F32]>:$output);
3257
3258  let hasOptions = 1;
3259}
3260
3261def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
3262    NoSideEffect,
3263    SameOperandsAndResultsScale,
3264    PredOpTrait<"input and output must have same element type",
3265      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3266    TFL_OperandHasRankRange<0, 3, 4>,
3267    TFL_OperandHasRank<1, 1>,
3268    TFL_OperandHasRank<2, 2>
3269  ]> {
3270  let summary = "BatchToSpaceNd operator";
3271
3272  let description = [{
3273    This operation reshapes the "batch" dimension 0 into space dimensions.
3274  }];
3275
3276  let arguments = (ins
3277    TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input,
3278    TFL_TensorOf<[I32]>:$block_shape,
3279    TFL_TensorOf<[I32]>:$indices
3280  );
3281
3282  let results = (outs
3283    TFL_TensorOf<[F32, I16, I32, I64, UI8, QI8, QUI8]>:$output
3284  );
3285}
3286
3287def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
3288    NoSideEffect,
3289    SameOperandsAndResultsScale,
3290    TFL_OperandHasRankRange<0, 3, 4>,
3291    PredOpTrait<"input and output must have same element type",
3292      TFL_TCresVTEtIsSameAsOp<0, 0>>
3293  ]> {
3294  let summary = "SpaceToBatchNd operator";
3295
3296  let description = [{
    This operation reshapes space dimensions into the "batch" dimension 0.
3298  }];
3299
3300  let arguments = (ins
3301    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
3302    TFL_I32Tensor:$block_shape,
3303    TFL_I32Tensor:$paddings
3304  );
3305
3306  let results = (outs
3307    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output
3308  );
3309}
3310
3311def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
3312    NoSideEffect,
3313    SameOperandsAndResultsScale,
3314    PredOpTrait<"input and output must have same element type",
3315      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3316    TFL_OperandHasRankAtMost<0, 4>
3317  ]> {
3318  let summary = "SpaceToDepth operator";
3319
3320  let description = [{
    Rearranges blocks of spatial data into depth. More specifically,
3322    this op outputs a copy of the input tensor where values from the `height`
3323    and `width` dimensions are moved to the `depth` dimension.
3324    `block_size` indicates the input block size.
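
    For illustration, a hypothetical use in MLIR assembly (a 4x4 spatial grid
    with block_size 2; shapes chosen only for the example):

    ```mlir
    %0 = "tfl.space_to_depth"(%input) {block_size = 2 : i32} : (tensor<1x4x4x1xf32>) -> tensor<1x2x2x4xf32>
    ```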
3325   }];
3326
3327  let arguments = (ins
3328    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
3329    Confined<I32Attr, [IntPositive]>:$block_size
3330  );
3331
3332  let results = (outs
3333    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output
3334  );
3335
3336  let hasOptions = 1;
3337}
3338
3339def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
3340    NoSideEffect,
3341    SameOperandsAndResultsScale,
3342    PredOpTrait<"input and output must have same element type",
3343      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3344    TFL_OperandHasRankAtMost<0, 4>
3345  ]> {
3346  let summary = "DepthToSpace operator";
3347
3348  let description = [{
3349    Rearranges data from depth into blocks of spatial data.
3350    This is the reverse transformation of SpaceToDepth. More specifically,
3351    this op outputs a copy of the input tensor where values from the `depth`
3352    dimension are moved in spatial blocks to the `height` and `width`
3353    dimensions. The attr `block_size` indicates the input block size and how
3354    the data is moved.
3355   }];
3356
3357  let arguments = (ins
3358    TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, UI8, QI8, QUI8]>:$input,
3359    Confined<I32Attr, [IntPositive]>:$block_size
3360  );
3361
3362  let results = (outs
3363    TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, UI8, QI8, QUI8]>:$output
3364  );
3365
3366  let hasOptions = 1;
3367}
3368
3369def TFL_SplitOp : TFL_Op<"split", [
3370    NoSideEffect,
3371    TFL_Operand0DOr1ElementTensor<0>,
3372    SameOperandsAndResultsScale]> {
3373  let summary = "Splits a tensor into `num_split` tensors along one dimension.";
3374
3375  let description = [{
3376    Splits the `value` tensor along `split_dim` into a number of sub-tensors
    with the same shape as the original one, except along `split_dim`. Same as
3378    tf.Split.
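
    For illustration, a hypothetical two-way split in MLIR assembly (shapes
    chosen only for the example):

    ```mlir
    %split_dim = "tfl.pseudo_const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %0:2 = "tfl.split"(%split_dim, %value) {num_splits = 2 : i32} : (tensor<i32>, tensor<4x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
    ```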
3379  }];
3380
3381  let arguments = (ins
3382    TFL_TensorOf<[I32]>:$split_dim,
3383    TFL_TensorOf<[F32, I16, I32, I8, UI8, QI8, QUI8, QI16]>:$value,
3384    Confined<I32Attr, [IntPositive]>:$num_splits
3385  );
3386
3387  let results = (outs
3388    TFL_VariadicTensorOf<[F32, I16, I32, I8, UI8, QI8, QUI8, QI16]>:$outputs
3389  );
3390
3391  let verifier = [{ return Verify(*this); }];
3392
3393  let hasOptions = 1;
3394}
3395
3396def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]> {
3397  let summary = "Splits a tensor into `num_split` tensors along one dimension.";
3398
3399  let description = [{
3400    Splits the `value` tensor along `split_dim` into a number of sub-tensors
    with the same shape as the original one, except along `split_dim`. The
    grouping of the resulting sub-tensors is determined by `size_splits`. Same
    as tf.SplitV.
3403  }];
3404
3405  let arguments = (ins
3406    TFL_TensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$value,
3407    TFL_1DTensorOf<[I32], [I32]>:$size_splits,
3408    TFL_0DTensorOf<[I32], [I32]>:$split_dim,
3409    Confined<I32Attr, [IntPositive]>:$num_splits
3410  );
3411
3412  let results = (outs
3413    TFL_VariadicTensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$outputs
3414  );
3415
3416  let verifier = [{ return Verify(*this); }];
3417
3418  let hasOptions = 1;
3419}
3420
3421def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
3422    NoSideEffect,
3423    PredOpTrait<"input and output must have same element type",
3424      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3425    TFL_OperandHasRank<0, 4>,
3426    TFL_OperandHasRank<1, 1>,
3427    SameOperandsAndResultsScale]> {
3428  let summary = "ResizeBilinear Op";
3429
3430  let description = [{
3431    Resize `images` to `size` using bilinear interpolation.
3432  }];
3433
3434  let arguments = (ins
3435    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
3436    TFL_I32Tensor:$size,
3437    BoolAttr:$align_corners,
3438    DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
3439  );
3440
3441  let results = (outs
3442    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
3443  );
3444
3445  let hasOptions = 1;
3446}
3447
3448def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
3449    NoSideEffect,
3450    PredOpTrait<"input and output must have same element type",
3451      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3452    TFL_OperandHasRank<0, 4>,
3453    TFL_OperandHasRank<1, 1>,
3454    SameOperandsAndResultsScale]> {
3455  let summary = "ResizeNearestNeighbor Op";
3456
3457  let description = [{
3458    Resize `images` to `size` using nearest neighbor interpolation.
3459  }];
3460
3461  let arguments = (ins
3462    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
3463    TFL_I32Tensor:$size,
3464    BoolAttr:$align_corners,
3465    DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
3466  );
3467
3468  let results = (outs
3469    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
3470  );
3471
3472  let hasOptions = 1;
3473}
3474
3475def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [
3476    NoSideEffect,
3477    PredOpTrait<"sparse_values and dense must have same element type",
3478      TFL_TCresVTEtIsSameAsOp<0, 2>>,
3479    PredOpTrait<"default_value and dense must have same element type",
3480      TFL_TCresVTEtIsSameAsOp<0, 3>>,
3481    TFL_OperandHasRankAtMost<0, 2>,
3482    TFL_OperandHasRankAtMost<1, 1>,
3483    TFL_OperandHasRankAtMost<2, 1>]> {
3484  let summary = "Converts a sparse representation into a dense tensor.";
3485
3486  let description = [{
3487Builds an array `dense` with shape `output_shape` such that
3488
3489```
3490# If sparse_indices is scalar
3491dense[i] = (i == sparse_indices ? sparse_values : default_value)
3492
3493# If sparse_indices is a vector, then for each i
3494dense[sparse_indices[i]] = sparse_values[i]
3495
3496# If sparse_indices is an n by d matrix, then for each i in [0, n)
3497dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
3498```
3499
3500All other values in `dense` are set to `default_value`.  If `sparse_values` is a
3501scalar, all sparse indices are set to this single value.
3502
3503Indices should be sorted in lexicographic order, and indices must not
3504contain any repeats. If `validate_indices` is true, these properties
3505are checked during execution.
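
For illustration, a hypothetical use in MLIR assembly (the vector case with
indices [[0], [2]]; shapes chosen only for the example):

```mlir
%0 = "tfl.sparse_to_dense"(%indices, %shape, %values, %default) : (tensor<2x1xi32>, tensor<1xi32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
```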
3506  }];
3507
3508  let arguments = (ins
3509    TFL_I32OrI64Tensor:$sparse_indices,
3510    TFL_I32OrI64Tensor:$output_shape,
3511    TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$sparse_values,
3512    TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$default_value
3513  );
3514
3515  let results = (outs
3516    TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$dense
3517  );
3518}
3519
3520def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
3521    NoSideEffect,
3522    PredOpTrait<"input and output must have same element type",
3523      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3524    SameOperandsAndResultsScale,
3525    TFL_OperandHasRankAtMost<0, 5>,
3526    TFL_OperandHasRank<1, 1>,
3527    TFL_OperandHasRank<2, 1>,
3528    TFL_OperandHasRank<3, 1>
3529  ]> {
3530  let summary = "StridedSlice Op";
3531
3532  let description = [{
3533    Return a strided slice from `input`.
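
    For illustration, a hypothetical use in MLIR assembly taking rows 1..2 of
    a 4x3 tensor, with begin = [1, 0], end = [3, 3], strides = [1, 1], and all
    masks zero (shapes chosen only for the example):

    ```mlir
    %0 = "tfl.strided_slice"(%input, %begin, %end, %strides) {begin_mask = 0 : i32, end_mask = 0 : i32, ellipsis_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<4x3xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xf32>
    ```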
3534  }];
3535
3536  let arguments = (ins
3537    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
3538    TFL_I32Tensor:$begin,
3539    TFL_I32Tensor:$end,
3540    TFL_I32Tensor:$strides,
3541
3542    I32Attr:$begin_mask,
3543    I32Attr:$end_mask,
3544    I32Attr:$ellipsis_mask,
3545    I32Attr:$new_axis_mask,
3546    I32Attr:$shrink_axis_mask
3547  );
3548
3549  let results = (outs
3550    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
3551  );
3552
3553  // TFLite kernel only supports up to 5D input including added axis.
3554  let verifier = [{ return Verify(*this); }];
3555
3556  let hasOptions = 1;
3557
3558  let hasFolder = 1;
3559}
3560
3561// If there is a change in supporting more types in the TFLite cast op kernel,
3562// the While loop outline pass should be updated since it inserts cast op(s)
3563// after the TF -> TFL legalization pass is done.
3564// LINT.IfChange
3565def TFL_CastOp : TFL_Op<"cast", [
3566    NoSideEffect,
3567    SameOperandsAndResultShape,
3568    NoQuantizableResult]> {
3569  let summary = "Cast operator";
3570
3571  let description = [{
3572    Casts input from input type to output type.
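
    For illustration, a hypothetical integer-to-float cast in MLIR assembly
    (shapes chosen only for the example):

    ```mlir
    %0 = "tfl.cast"(%input) : (tensor<4xi32>) -> tensor<4xf32>
    ```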
3573  }];
3574
3575  let arguments = (ins
3576    TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
3577  );
3578
3579  let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
3580
3581  // TFLite's cast op does not utilize CastOptions, instead derives types
3582  // from the TfLiteTensors.
3583  let hasOptions = 0;
3584
3585  let hasFolder = 1;
3586}
3587// LINT.ThenChange(//tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc)
3588
3589def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
3590                     NoSideEffect, TFL_OperandHasRank<1, 2>]> {
3591  let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
3592
3593  let description = [{
    This operation pads an `input` tensor with mirrored values according to
    the `paddings` you specify. `paddings` is an integer tensor with shape
    [n, 2], where n is the rank of `input`. For each dimension D of `input`,
    paddings[D, 0] indicates how many values to add before the contents of
    `input` in that dimension, and paddings[D, 1] indicates how many values to
    add after the contents of `input` in that dimension.
3601
3602    Both paddings[D, 0] and paddings[D, 1] must be no greater than
3603    input.dim_size(D) (or input.dim_size(D) - 1)
3604    if copy_border is true (if false, respectively).
3605
3606    The padded size of each dimension D of the output is:
3607
3608    paddings(D, 0) + input.dim_size(D) + paddings(D, 1)
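
    For illustration, a hypothetical REFLECT-mode pad in MLIR assembly, adding
    one mirrored row before and after dimension 0 (shapes chosen only for the
    example):

    ```mlir
    %pad = "tfl.pseudo_const"() {value = dense<[[1, 1], [0, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
    %0 = "tfl.mirror_pad"(%input, %pad) {mode = "REFLECT"} : (tensor<2x3xf32>, tensor<2x2xi32>) -> tensor<4x3xf32>
    ```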
3609  }];
3610
3611  let arguments = (ins
3612    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$input,
3613    TFL_TensorOf<[I32, I64]>:$pad,
3614    TFL_MirrorPaddingAttr:$mode
3615  );
3616
3617  let results = (outs
3618    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$output
3619  );
3620
3621  let hasOptions = 1;
3622}
3623
3624def TFL_UniqueOp: TFL_Op<"unique", [
3625    TFL_OperandHasRank<0, 1>,
3626    NoSideEffect]> {
3627  let summary = "Unique Op.";
3628
3629  let description = [{
  This operation returns a tensor `output` containing all of the unique elements
of `input` sorted in the same order that they occur in `input`. This operation
also returns a tensor `idx` the same size as `input` that contains the index of
each value of `input` in the unique output `output`. In other words,
`output[idx[i]] = input[i]` for every element of `input`.
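
For illustration, a hypothetical use in MLIR assembly (for input
[1, 1, 2, 4, 4, 4], output is [1, 2, 4] and idx is [0, 0, 1, 2, 2, 2]; the
number of unique elements is not statically known, hence the dynamic dimension):

```mlir
%output, %idx = "tfl.unique"(%input) : (tensor<6xi32>) -> (tensor<?xi32>, tensor<6xi32>)
```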
3634  }];
3635
3636  let arguments = (ins
3637    TFL_TensorOf<[I8, QI8, UI8, QUI8, I16, QI16, I32, I64, F32]>:$input
3638  );
3639
3640  let results = (outs
3641    TFL_TensorOf<[I8, QI8, UI8, QUI8, I16, QI16, I32, I64, F32]>:$output,
3642    TFL_I32OrI64Tensor:$idx
3643  );
3644
3645  DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
3646    return getResult(1).getType().cast<TensorType>().getElementType().
3647        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
3648            tflite::TensorType_INT32;
3649    }], [{
3650      TypeAttr::get(getResult(1).getType().cast<TensorType>().getElementType())
3651    }]>;
3652
3653  let hasOptions = 1;
3654}
3655
3656//===----------------------------------------------------------------------===//
3657// Quantization ops.
3658//===----------------------------------------------------------------------===//
3659def TFL_DequantizeOp: TFL_Op<"dequantize", [NoQuantizableResult]> {
3660  let summary = "Dequantize operator";
3661
3662  let description = [{
    Converts a quantized array of integers to floating-point values according
    to the quantization parameters.
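
    For illustration, a hypothetical use in MLIR assembly (quantization
    parameters chosen only for the example):

    ```mlir
    %0 = "tfl.dequantize"(%input) : (tensor<4x!quant.uniform<u8:f32, 0.1:128>>) -> tensor<4xf32>
    ```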
3665  }];
3666
3667  let arguments = (ins TFL_TensorOf<[QI8, QUI8, QI16, F16]>:$input);
3668
3669  let results = (outs TFL_FpTensor:$output);
3670}
3671
3672def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
3673  let summary = "FakeQuant operator";
3674
3675  let description = [{
    Fake-quantizes the float 'input' tensor via the float scalars `min` and
    `max`, producing an 'output' tensor of the same shape as the input.
3678  }];
3679
3680  let arguments = (
3681    ins TFL_FpTensor:$input,
3682    // The expected [min, max] range of values.
3683    F32Attr:$min,
3684    F32Attr:$max,
3685
3686    // The bitwidth of the quantization; between 2 and 16, inclusive.
3687    Confined<I32Attr, [IntMinValue<2>, IntMaxValue<16>]>:$num_bits,
3688    // Quantization range starts from 0 or 1; starts from 1 if true.
3689    Confined<BoolAttr, [TFL_BoolFalse]>:$narrow_range);
3690
3691  let results = (outs TFL_FpTensor:$output);
3692
  let hasCanonicalizer = 1;
3694
3695  let hasOptions = 1;
3696}
3697
3698def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
3699    NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
3700  let summary = "Quantized constant pseudo op";
3701
3702  let description = [{
3703    Represents a quantized constant value in TensorFlow Lite dialect. This is
    not an actual operation; it will be lowered to a buffer instead. The
3705    quantization parameters are stored as a type attribute in this constant.
3706  }];
3707
3708  let arguments = (
3709    ins TensorTypeAttr:$qtype,
3710    ElementsAttr:$value
3711  );
3712
3713  let results = (outs TFL_TensorOf<[QUI8, QI8, QI16, QUI16, TFL_Quint8]>:$output);
3714
3715  let builders = [
3716    OpBuilder<(ins "TypeAttr":$qtype, "Attribute":$value),
3717    [{
3718      $_state.addAttribute("qtype", qtype);
3719      $_state.addAttribute("value", value);
3720      $_state.addTypes(qtype.getValue());
3721    }]>
3722  ];
3723}
3724
3725def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
3726    NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
3727  let summary = "Sparse quantized constant pseudo op";
3728
3729  let description = [{
3730    Represents a sparse quantized constant value in TensorFlow Lite dialect.
    This is not an actual operation; it will be lowered to a buffer instead.
3732    The quantization parameters are stored as a type attribute in this constant.
3733  }];
3734
3735  let arguments = (
3736    ins TensorTypeAttr:$qtype,
3737    ElementsAttr:$value,
3738    SparsityParameterAttr:$s_param,
3739    ElementsAttr:$compressed_data
3740  );
3741
3742  let results = (outs TFL_TensorOf<[QUI8, QI8, QI16, QUI16, TFL_Quint8]>:$output);
3743
3744  let builders = [
3745    OpBuilder<(ins "TypeAttr":$qtype, "Attribute":$value,
3746      "SparsityParameterAttr":$s_param, "Attribute":$compressed_data),
3747    [{
3748      $_state.addTypes(qtype.getValue());
3749      $_state.addAttribute("qtype", qtype);
3750      $_state.addAttribute("value", value);
3751      $_state.addAttribute("s_param", s_param);
3752      $_state.addAttribute("compressed_data", compressed_data);
3753    }]>
3754  ];
3755}
3756
3757def TFL_QuantizeOp: TFL_Op<"quantize", [
3758    FirstAttrDerivedResultType,
3759    SameOperandsAndResultShape,
3760    NoQuantizableResult]> {
3761  let summary = "Quantize operator";
3762
3763  let description = [{
3764    Converts floating point tensors to quantized integer tensors according to
3765    the quantization parameters defined in the type attribute.
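
    For illustration, a hypothetical use in MLIR assembly (quantization
    parameters chosen only for the example):

    ```mlir
    %0 = "tfl.quantize"(%input) {qtype = tensor<4x!quant.uniform<u8:f32, 0.1:128>>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 0.1:128>>
    ```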
3766  }];
3767
3768  let arguments = (
3769    ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input,
3770    TensorTypeAttr:$qtype
3771  );
3772
3773  let results = (outs TFL_TensorOf<[QI8, QUI8, QI16, TFL_Quint8]>:$output);
3774}
3775
3776def TFL_DensifyOp: TFL_Op<"densify", [
3777    NoSideEffect,
3778    PredOpTrait<"input and output must have same element type",
3779      TFL_TCresVTEtIsSameAsOp<0, 0>>,
3780    NoQuantizableResult]> {
3781  let summary = "Densify operator";
3782
3783  let description = [{
    Converts a sparse tensor to dense format.
3785  }];
3786
3787  let arguments = (ins TFL_TensorOf<[F32, I8]>:$input);
3788
3789  let results = (outs TFL_TensorOf<[F32, I8]>:$output);
3790}
3791
3792//===----------------------------------------------------------------------===//
3793// LSTM Ops
3794//===----------------------------------------------------------------------===//
3795
3796// LSTM Kernel Type attributes
3797def TFL_LSTM_KT_FULL  : StrEnumAttrCase<"FULL">;
3798def TFL_LSTM_KT_BASIC  : StrEnumAttrCase<"BASIC">;
3799
3800def TFL_LSTMKernelTypeAttr : StrEnumAttr<"LSTMKernelType", "lstm kernel type enum",
3801   [
3802     TFL_LSTM_KT_FULL,  TFL_LSTM_KT_BASIC
3803   ]>;
3804
3805def LstmMandatoryInputsConstraint : PredOpTrait<
3806  "mandatory operands element types should match",
3807  // TODO(ashwinm): Replace the indices with input tensor names when that
3808  // support is available.
3809  Or<[
3810    TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>,
3811    Neg<TypeIsPred<"input", F32>>]>>;
3812
3813def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
3814  "the optional peephole weights should all be specified or none",
3815  // Ignore input 9 (cell_to_input_weights) for LSTM with CIFG.
3816  And<[
3817    TFL_TCopVTEtAreSameAt<10, 11, 16>,
3818    Or<[TFL_TCopVTEtAreSameAt<9, 10, 16>,
3819        And<[TypeIsPred<"input_to_input_weights", NoneType>,
3820             TypeIsPred<"cell_to_input_weights", NoneType>]>]>]>>;
3821
3822def LstmProjectionWeightBiasConstraint : PredOpTrait<
3823  "either projection weight must be specified or both projection weight and "
3824  "projection bias must not be specified",
3825   Or<[
3826      And<[TypeIsPred<"projection_weights", NoneType>,
3827           TypeIsPred<"projection_bias", NoneType>]>,
3828      Neg<TypeIsPred<"projection_weights", NoneType>>]>>;
3829
3830def LstmCifgInputConstraint : PredOpTrait<
3831  "the cifg inputs should all be specified or none",
3832   // If LSTM has combined input/forget gate, input 1, 5, 9, 12, 20 are all none
3833   // or 1, 5, 12 should not be none. Inputs 9 and 20 depend on LSTM's variants.
3834   Or<[
3835     And<[TypeIsPred<"input_to_input_weights", NoneType>,
3836          TypeIsPred<"recurrent_to_input_weights", NoneType>,
3837          TypeIsPred<"cell_to_input_weights", NoneType>,
3838          TypeIsPred<"input_gate_bias", NoneType>,
3839          TypeIsPred<"input_layer_norm_coefficients", NoneType>]>,
3840     Neg<Or<[
3841       TypeIsPred<"input_to_input_weights", NoneType>,
3842       TypeIsPred<"recurrent_to_input_weights", NoneType>,
3843       TypeIsPred<"input_gate_bias", NoneType>]>>]>>;
3844
3845
3846// TODO(b/137798843): Need to add an additional constraint for both LSTM and
3847// UnidirectionalSequenceLstm
3848// For layer norm: if layer norm is false, tensor {20, 21, 22, 23}
3849// are null; if layer norm is true, tensors {21, 22, 23} are not null; tensor
3850// {20} is not null if additionally cifg = false.
3851
3852def LstmResultConstraint : PredOpTrait<
3853  "the input and result tensor elemental types must be same",
3854  TFL_TCresVTEtIsSameAsOp<0, 0>>;
3855
3856// This is the basic kernel type LSTM op.
3857// TODO(b/142417845): Refactor this part to return its tflite node name as
3858// "lstm".
3859def TFL_BasicLSTMOp : TFL_Op<"basic_lstm", [NoSideEffect,
3860    TFL_OperandHasRank<0, 2>, TFL_OperandHasRank<1, 2>, TFL_OperandHasRank<2, 2>,
3861    TFL_OperandHasRank<3, 1>, TFL_OperandHasRank<4, 2>]> {
3862  let summary = "The basic lstm operator";
3863
3864  let description = [{
    Basic LSTM cell operator.
3866  }];
3867
3868  let arguments = (
3869    ins TFL_TensorOf<[F32, QUI8]>:$data_input,
3870    TFL_TensorOf<[F32, QUI8]>:$prev_activ_input,
3871    TFL_TensorOf<[F32, QUI8]>:$weights_input,
3872    TFL_TensorOf<[F32, QI32]>:$biases_input,
3873    TFL_TensorOf<[F32, QI16]>:$prev_state_input,
3874
3875    // Attributes
3876    DefaultValuedAttr<TFL_AFAttr, "TANH">:$fused_activation_function,
3877    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
3878    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
3879    // Since this op is the BASIC kernel only, constrain it.
3880    Confined<
3881      DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "BASIC">,
3882      [TFL_LSTM_KT_BASIC]>:$kernel_type
3883  );
3884
3885  let hasOptions = 1;
3886
3887  let results = (outs TFL_2DTensorOf<[F32, QUI8]>:$activ_output,
3888                      TFL_2DTensorOf<[F32, QUI16]>:$state_output,
3889                      TFL_2DTensorOf<[F32, QUI8]>:$concat_temp,
3890                      TFL_2DTensorOf<[F32, QUI16]>:$activ_temp);
3891}
3892
3893// This is the FULL kernel type LSTM op.
3894def TFL_LSTMOp :
3895  TFL_Op<"lstm",
3896          [LstmMandatoryInputsConstraint,
3897           LstmOptionalPeepholeWeightConstraint,
3898           LstmProjectionWeightBiasConstraint,
3899           LstmCifgInputConstraint,
3900           LstmResultConstraint,
3901           TFL_OperandHasRank<2, 2>,           // input_to_forget_weights
3902           TFL_OperandHasRank<3, 2>,           // input_to_cell_weights
3903           TFL_OperandIsNoneOrHasRank<5, 2>,   // recurrent_to_input_weights
3904           TFL_OperandHasRank<6, 2>,           // recurrent_to_forget_weights
3905           TFL_OperandHasRank<7, 2>,           // recurrent_to_cell_weights
3906           TFL_OperandIsNoneOrHasRank<9, 1>,   // cell_to_input_weights
3907           TFL_OperandIsNoneOrHasRank<10, 1>,  // cell_to_forget_weights
3908           TFL_OperandIsNoneOrHasRank<11, 1>,  // cell_to_output_weights
3909           TFL_OperandHasRank<13, 1>,          // forget_gate_bias
3910           TFL_OperandHasRank<14, 1>,          // cell_gate_bias
3911           TFL_OperandHasRank<15, 1>,          // output_gate_bias
3912           TFL_OperandIsNoneOrHasRank<16, 2>,  // projection_weights
3913           TFL_OperandIsNoneOrHasRank<17, 1>,  // projection_bias
3914           TFL_StatefulOp]> {
3915  let summary = "The full lstm operator";
3916
3917  let description = [{
3918Long short-term memory unit (LSTM) recurrent network layer.
3919The default non-peephole implementation is based on:
3920http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
3921S. Hochreiter and J. Schmidhuber. 'Long Short-Term Memory'. Neural Computation,
39229(8):1735-1780, 1997.
3923The peephole implementation is based on:
3924https://research.google.com/pubs/archive/43905.pdf
3925Hasim Sak, Andrew Senior, and Francoise Beaufays. 'Long short-term memory
3926recurrent neural network architectures for large scale acoustic modeling.'
3927INTERSPEECH, 2014.
3928The coupling of input and forget gate (CIFG) is based on:
3929http://arxiv.org/pdf/1503.04069.pdf
3930Greff et al. 'LSTM: A Search Space Odyssey'
3931The layer normalization is based on:
3932https://arxiv.org/pdf/1607.06450.pdf
3933Ba et al. 'Layer Normalization'
3934  }];
3935
3936  let arguments = (
3937    ins TFL_TensorOf<[F32, QI8]>:$input,
3938
3939    // Weights
3940    TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights,
3941    TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights,
3942    TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights,
3943    TFL_TensorOf<[F32, QI8]>:$input_to_output_weights,
3944
3945    // Recurrent weights
3946    TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights,
3947    TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights,
3948    TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights,
3949    TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights,
3950
3951    // Cell weights
3952    TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_input_weights,
3953    // Optional input
3954    TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_forget_weights,
3955    // Optional input
3956    TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_output_weights,
3957
3958    // Bias
3959    TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
3960    TFL_TensorOf<[F32, QI32]>:$forget_gate_bias,
3961    TFL_TensorOf<[F32, QI32]>:$cell_bias,
3962    TFL_TensorOf<[F32, QI32]>:$output_gate_bias,
3963
3964    // Projection weight and bias
3965    TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights,
3966    // Optional input
3967    TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias,
3968
3969    // Stateful activation and cell states.
3970    TFL_StatefulTensor:$input_activation_state,
3971    TFL_StatefulTensor:$input_cell_state,
3972
3973    // Layer norm coefficients
3974    TFL_TensorOfOrNone<[F32, QI16]>:$input_layer_norm_coefficients,
3975    TFL_TensorOfOrNone<[F32, QI16]>:$forget_layer_norm_coefficients,
3976    TFL_TensorOfOrNone<[F32, QI16]>:$cell_layer_norm_coefficients,
3977    TFL_TensorOfOrNone<[F32, QI16]>:$output_layer_norm_coefficients,
3978
3979    // Attributes
3980    TFL_AFAttr:$fused_activation_function,
3981    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
3982    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
3983    // Since this op is the FULL kernel only, constrain it.
3984    Confined<
3985      DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">,
3986      [TFL_LSTM_KT_FULL]>:$kernel_type,
3987
3988    // Types of the optional intermediate tensors, which exist for fully
3989    // quantized LSTM op and hold the ranges of the intermediate tensors.
    // The types of the intermediate tensors are quant.calibrated when imported,
    // storing only calibrated min/max values. The proper quantization spec is
3992    // determined while going through quantization passes.
3993    OptionalAttr<TypeAttr>:$input_to_input_intermediate,
3994    OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
3995    OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
3996    OptionalAttr<TypeAttr>:$input_to_output_intermediate,
3997    OptionalAttr<TypeAttr>:$effective_hidden_scale_intermediate
3998  );
3999
4000  let results = (outs AnyTensor:$output);
4001
4002  // TODO(fengliuai): customize printer and parser to not display
4003  // empty region.
4004  let regions = (region AnyRegion:$internal);
4005
4006  let hasOptions = 1;
4007
4008  let hasCanonicalizer = 1;
4009
4010  let verifier = [{ return Verify(*this); }];
4011
4012  let extraClassDeclaration = [{
4013    // StatefulOpInterface:
4014    std::vector<int> GetStatefulOperands() { return {18, 19}; }
4015  }];
4016}
4017
4018// UnidirectionalSequenceLstm op.
4019// TODO(ashwinm): Add constraint to validate the combination of operands
4020// that are valid for hybrid vs fully quantized vs float only semantics
4021def TFL_UnidirectionalSequenceLSTMOp :
4022  TFL_Op<"unidirectional_sequence_lstm",
4023          [LstmMandatoryInputsConstraint,
4024           LstmOptionalPeepholeWeightConstraint,
4025           LstmProjectionWeightBiasConstraint,
4026           LstmCifgInputConstraint,
4027           LstmResultConstraint,
4028           TFL_OperandHasRankAtLeast<0, 2>,    // input
4029           TFL_OperandIsNoneOrHasRank<1, 2>,   // input_to_input_weights
4030           TFL_OperandHasRank<2, 2>,           // input_to_forget_weights
4031           TFL_OperandHasRank<3, 2>,           // input_to_cell_weights
4032           TFL_OperandHasRank<4, 2>,           // input_to_output_weights
4033           TFL_OperandIsNoneOrHasRank<5, 2>,   // recurrent_to_input_weights
4034           TFL_OperandHasRank<6, 2>,           // recurrent_to_forget_weights
4035           TFL_OperandHasRank<7, 2>,           // recurrent_to_cell_weights
4036           TFL_OperandHasRank<8, 2>,           // recurrent_to_output_weights
4037           TFL_OperandIsNoneOrHasRank<9, 1>,   // cell_to_input_weights
4038           TFL_OperandIsNoneOrHasRank<10, 1>,  // cell_to_forget_weights
4039           TFL_OperandIsNoneOrHasRank<11, 1>,  // cell_to_output_weights
4040           TFL_OperandIsNoneOrHasRank<12, 1>,  // input_gate_bias
4041           TFL_OperandHasRank<13, 1>,          // forget_gate_bias
4042           TFL_OperandHasRank<14, 1>,          // cell_gate_bias
4043           TFL_OperandHasRank<15, 1>,          // output_gate_bias
4044           TFL_OperandIsNoneOrHasRank<16, 2>,  // projection_weights
4045           TFL_OperandIsNoneOrHasRank<17, 1>,  // projection_bias
4046           TFL_StatefulOp,
4047           DeclareOpInterfaceMethods<InferTypeOpInterface>
4048          ]> {
4049  let summary = "Unidirectional sequence lstm operator";
4050
4051  let description = [{
4052    A recurrent neural network specified by an LSTM cell. This Op supports
4053    unrolling the input along the time or batch dimensions, and
4054    implements the following operation for
4055    each element in the sequence s = 1...sequence_length:
4056      outputs[s] = state = activation(LSTMOp(inputs[s]))
4057
    where LSTMOp is the LSTM TF Lite op and the “activation” is the function
    passed as the “fused_activation_function” argument (if not “NONE”).
4060  }];
4061
4062  let arguments = (
4063    ins TFL_FpTensor:$input,
4064
4065    // Weights
4066    TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights,
4067    TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights,
4068    TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights,
4069    TFL_TensorOf<[F32, QI8]>:$input_to_output_weights,
4070
4071    // Recurrent weights
4072    TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights,
4073    TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights,
4074    TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights,
4075    TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights,
4076
4077    // Cell weights
4078    TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_input_weights,
4079    // Optional input
4080    TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_forget_weights,
4081    // Optional input
4082    TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_output_weights,
4083
4084    // Bias
4085    TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
4086    TFL_FpTensor:$forget_gate_bias,
4087    TFL_FpTensor:$cell_bias,
4088    TFL_FpTensor:$output_gate_bias,
4089
4090    // Projection weight and bias
4091    TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights,
4092    // Optional input
4093    TFL_TensorOfOrNone<[F32]>:$projection_bias,
4094
4095    // Stateful activation and cell states.
4096    TFL_StatefulTensor:$input_activation_state,
4097    TFL_StatefulTensor:$input_cell_state,
4098
4099    // Layer norm coefficients
4100    TFL_TensorOfOrNone<[F32, QI8]>:$input_layer_norm_coefficients,
4101    TFL_TensorOfOrNone<[F32, QI8]>:$forget_layer_norm_coefficients,
4102    TFL_TensorOfOrNone<[F32, QI8]>:$cell_layer_norm_coefficients,
4103    TFL_TensorOfOrNone<[F32, QI8]>:$output_layer_norm_coefficients,
4104
4105    // Attributes
4106    TFL_AFAttr:$fused_activation_function,
4107    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
4108    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
4109    BoolAttr:$time_major,
4110
4111    // Types of the optional intermediate tensors, which exist for fully
4112    // quantized op and hold the ranges of the intermediate tensors.
    // The types of the intermediate tensors are quant.calibrated when imported,
    // storing only calibrated min/max values. The proper quantization spec is
4115    // determined while going through quantization passes.
4116    OptionalAttr<TypeAttr>:$input_to_input_intermediate,
4117    OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
4118    OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
4119    OptionalAttr<TypeAttr>:$input_to_output_intermediate,
4120    OptionalAttr<TypeAttr>:$effective_hidden_scale_intermediate
4121  );
4122
4123  let results = (outs TFL_TensorOf<[F32, QI8]>:$output);
4124
4125  let hasOptions = 1;
4126
4127  let verifier = [{ return Verify(*this); }];
4128
4129  let extraClassDeclaration = [{
4130    // StatefulOpInterface:
4131    std::vector<int> GetStatefulOperands() { return {18, 19}; }
4132
    // Compatible return types check
4134    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
4135  }];
4136}
4137
4138def BidiLstmMandatoryInputsConstraint : PredOpTrait<
4139  "mandatory operands element types should match",
4140  // TODO(ashwinm): Replace the indices with input tensor names when that
4141  // support is available.
4142  Or<[
4143    TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 19, 20, 21, 23, 24, 25,
4144                       30, 31, 32, 35, 36, 37, 38]>,
4145    Neg<TypeIsPred<"input", F32>>]>>;
4146
4147// TODO(b/172517537): support quantized types
4148def BidiLstmOptionalPeepholeWeightConstraint : PredOpTrait<
4149  "the optional peephole weights should all be specified or none",
4150  TCopVTEtAreSameAt<[9, 10, 11, 26, 27, 28]>>;
4151
4152def BidiLstmProjectionWeightBiasConstraint : PredOpTrait<
4153  "either projection weight must be specified or both projection weight and "
4154  "projection bias must not be specified",
4155   Or<[
4156      And<[TypeIsPred<"fw_projection_weights", NoneType>,
4157           TypeIsPred<"fw_projection_bias", NoneType>,
4158           TypeIsPred<"bw_projection_weights", NoneType>,
4159           TypeIsPred<"bw_projection_bias", NoneType>]>,
4160      And<[
4161        Neg<TypeIsPred<"fw_projection_weights", NoneType>>,
4162        Neg<TypeIsPred<"bw_projection_weights", NoneType>>,
4163     ]>
4164   ]>>;
4165
4166// BidirectionalSequenceLstm op.
4167// TODO(ashwinm): Add constraint to validate the combination of operands
4168// that are valid for hybrid vs fully quantized vs float only semantics
4169def TFL_BidirectionalSequenceLSTMOp :
4170  TFL_Op<"bidirectional_sequence_lstm",
4171          [BidiLstmMandatoryInputsConstraint,
4172           BidiLstmOptionalPeepholeWeightConstraint,
4173           BidiLstmProjectionWeightBiasConstraint,
4174           LstmResultConstraint,
4175           TFL_OperandHasRank<0, 3>,   // input
4176           TFL_OperandHasRank<1, 2>,   // fw_input_to_input_weights
4177           TFL_OperandHasRank<2, 2>,   // fw_input_to_forget_weights
4178           TFL_OperandHasRank<3, 2>,   // fw_input_to_cell_weights
4179           TFL_OperandHasRank<4, 2>,   // fw_input_to_output_weights
4180           TFL_OperandHasRank<5, 2>,   // fw_recurrent_to_input_weights
4181           TFL_OperandHasRank<6, 2>,   // fw_recurrent_to_forget_weights
4182           TFL_OperandHasRank<7, 2>,   // fw_recurrent_to_cell_weights
4183           TFL_OperandHasRank<8, 2>,   // fw_recurrent_to_output_weights
4184           TFL_OperandHasRank<9, 1>,   // fw_cell_to_input_weights
4185           TFL_OperandHasRank<10, 1>,  // fw_cell_to_forget_weights
4186           TFL_OperandHasRank<11, 1>,  // fw_cell_to_output_weights
4187           TFL_OperandHasRank<12, 1>,  // fw_input_gate_bias
4188           TFL_OperandHasRank<13, 1>,  // fw_forget_gate_bias
4189           TFL_OperandHasRank<14, 1>,  // fw_cell_bias
4190           TFL_OperandHasRank<15, 1>,  // fw_output_gate_bias
4191           TFL_OperandHasRank<16, 2>,  // fw_projection_weights
4192           TFL_OperandHasRank<17, 1>,  // fw_projection_bias
4193           TFL_OperandHasRank<18, 2>,  // bw_input_to_input_weights
4194           TFL_OperandHasRank<19, 2>,  // bw_input_to_forget_weights
4195           TFL_OperandHasRank<20, 2>,  // bw_input_to_cell_weights
4196           TFL_OperandHasRank<21, 2>,  // bw_input_to_output_weights
4197           TFL_OperandHasRank<22, 2>,  // bw_recurrent_to_input_weights
4198           TFL_OperandHasRank<23, 2>,  // bw_recurrent_to_forget_weights
4199           TFL_OperandHasRank<24, 2>,  // bw_recurrent_to_cell_weights
4200           TFL_OperandHasRank<25, 2>,  // bw_recurrent_to_output_weights
4201           TFL_OperandHasRank<26, 1>,  // bw_cell_to_input_weights
4202           TFL_OperandHasRank<27, 1>,  // bw_cell_to_forget_weights
4203           TFL_OperandHasRank<28, 1>,  // bw_cell_to_output_weights
4204           TFL_OperandHasRank<29, 1>,  // bw_input_gate_bias
4205           TFL_OperandHasRank<30, 1>,  // bw_forget_gate_bias
4206           TFL_OperandHasRank<31, 1>,  // bw_cell_bias
4207           TFL_OperandHasRank<32, 1>,  // bw_output_gate_bias
4208           TFL_OperandHasRank<33, 2>,  // bw_projection_weights
4209           TFL_OperandHasRank<34, 1>,  // bw_projection_bias
4210           TFL_StatefulOp]> {
4211  let summary = "Bidirectional sequence lstm operator";
4212
4213  let description = [{
    A bidirectional LSTM is essentially two LSTMs, one running forward and the
    other running backward, with the output being the concatenation of the
    two LSTMs.
4217  }];
4218
4219  let arguments = (
4220    ins TFL_TensorOf<[F32, I8]>:$input,
4221
4222    // Forward LSTM Weights
4223    TFL_TensorOfOrNone<[F32, I8]>:$fw_input_to_input_weights,
4224    TFL_TensorOf<[F32, I8]>:$fw_input_to_forget_weights,
4225    TFL_TensorOf<[F32, I8]>:$fw_input_to_cell_weights,
4226    TFL_TensorOf<[F32, I8]>:$fw_input_to_output_weights,
4227
4228    // Forward Recurrent weights
4229    TFL_TensorOfOrNone<[F32, I8]>:$fw_recurrent_to_input_weights,
4230    TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_forget_weights,
4231    TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_cell_weights,
4232    TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_output_weights,
4233
4234    // Forward Cell weights
4235    TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_input_weights,
4236    // Optional Forward cell weights
4237    TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_forget_weights,
4238    // Optional Forward cell weights
4239    TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_output_weights,
4240
4241    // Forward Bias
4242    TFL_TensorOfOrNone<[F32]>:$fw_input_gate_bias,
4243    TFL_TensorOf<[F32]>:$fw_forget_gate_bias,
4244    TFL_TensorOf<[F32]>:$fw_cell_bias,
4245    TFL_TensorOf<[F32]>:$fw_output_gate_bias,
4246
4247    // Forward Projection weight and bias
4248    TFL_TensorOfOrNone<[F32, I8]>:$fw_projection_weights,
4249    // Forward Optional input
4250    TFL_TensorOfOrNone<[F32]>:$fw_projection_bias,
4251
4252    // Backward LSTM Weights
4253    TFL_TensorOfOrNone<[F32, I8]>:$bw_input_to_input_weights,
4254    TFL_TensorOf<[F32, I8]>:$bw_input_to_forget_weights,
4255    TFL_TensorOf<[F32, I8]>:$bw_input_to_cell_weights,
4256    TFL_TensorOf<[F32, I8]>:$bw_input_to_output_weights,
4257
4258    // Backward Recurrent weights
4259    TFL_TensorOfOrNone<[F32, I8]>:$bw_recurrent_to_input_weights,
4260    TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_forget_weights,
4261    TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_cell_weights,
4262    TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_output_weights,
4263
4264    // Backward Cell weights
4265    TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_input_weights,
    // Optional Backward cell weights
4267    TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_forget_weights,
    // Optional Backward cell weights
4269    TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_output_weights,
4270
4271    // Backward Bias
4272    TFL_TensorOfOrNone<[F32]>:$bw_input_gate_bias,
4273    TFL_TensorOf<[F32]>:$bw_forget_gate_bias,
4274    TFL_TensorOf<[F32]>:$bw_cell_bias,
4275    TFL_TensorOf<[F32]>:$bw_output_gate_bias,
4276
4277    // Backward Projection weight and bias
4278    TFL_TensorOfOrNone<[F32, I8]>:$bw_projection_weights,
4279    // Backward Optional input
4280    TFL_TensorOfOrNone<[F32]>:$bw_projection_bias,
4281
4282    // Stateful activation and cell states.
4283    TFL_StatefulTensor:$fw_input_activation_state,
4284    TFL_StatefulTensor:$fw_input_cell_state,
4285    TFL_StatefulTensor:$bw_input_activation_state,
4286    TFL_StatefulTensor:$bw_input_cell_state,
4287
4288    // Auxiliary input & weights.
4289    TFL_TensorOfOrNone<[F32, I8]>:$aux_input,
4290    // Auxiliary fw weights.
4291    TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_input_weights,
4292    TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_forget_weights,
4293    TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_cell_weights,
4294    TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_output_weights,
4295    // Auxiliary bw weights.
4296    TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_input_weights,
4297    TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_forget_weights,
4298    TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_cell_weights,
4299    TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_output_weights,
4300
4301    // Attributes
4302    TFL_AFAttr:$fused_activation_function,
4303    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
4304    Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
4305    BoolAttr:$merge_outputs,
4306    BoolAttr:$time_major
4307  );
4308
4309  let results = (outs
4310    AnyTensor:$fw_output,
4311    AnyTensor:$bw_output
4312  );
4313
4314  let hasOptions = 1;
4315
4316  let verifier = [{ return Verify(*this); }];
4317
4318  let extraClassDeclaration = [{
4319    // StatefulOpInterface:
4320    std::vector<int> GetStatefulOperands() { return {35, 36, 37, 38}; }
4321  }];
4322}
4323
4324// UnidirectionalSequenceRNN op.
4325def TFL_UnidirectionalSequenceRNNOp : TFL_Op<"unidirectional_sequence_rnn", [
4326    TFL_OperandHasRank<4, 2>,
4327    PredOpTrait<"input and output must have same element type",
4328      TFL_TCresVTEtIsSameAsOp<0, 0>>,
4329    PredOpTrait<"input and constant value operands must have same element type",
4330      TFL_TCopVTEtAreSameAt<1, 2>>,
4331    TFL_StatefulOp]> {
4332  let summary = "Unidirectional sequence rnn operator";
4333
4334  let description = [{
    A recurrent neural network specified by an RNN cell. This op takes input
    in the format {batch_size, seq_len, input_size}, or
    {seq_len, batch_size, input_size} if it is time-major.
4338
4339    It implements the following operation for
4340    each element in the sequence s = 1...sequence_length:
4341      outputs[s] = state = activation(RNNOp(inputs[s]))
4342
    where RNNOp is the RNN TF Lite op and the “activation” is the function
    passed as the “fused_activation_function” argument (if not “NONE”).
4345  }];
4346
4347  let arguments = (
4348    ins TFL_FpTensor:$input,
4349
4350    // Weights
4351    TFL_TensorOf<[F32, QI8]>:$input_to_input_weights,
4352
4353    // Recurrent weights
4354    TFL_TensorOf<[F32, QI8]>:$recurrent_to_input_weights,
4355
4356    // Bias
4357    TFL_FpTensor:$input_gate_bias,
4358
4359    // Hidden state.
4360    TFL_StatefulTensor:$hidden_state,
4361
4362    // Attributes
4363    BoolAttr:$time_major,
4364    TFL_AFAttr:$fused_activation_function
4365  );
4366
4367  let results = (outs TFL_FpTensor:$output);
4368
4369  let hasOptions = 1;
4370
4371  let customOption = "SequenceRNNOptions";
4372
4373  let verifier = [{ return Verify(*this); }];
4374
4375  let extraClassDeclaration = [{
4376    // StatefulOpInterface:
4377    std::vector<int> GetStatefulOperands() { return {4}; }
4378  }];
4379}
4380
4381def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
4382  let summary = "Returns locations of nonzero / true values in a tensor.";
4383
4384  let description = [{
4385This operation returns the coordinates of true elements in `condition`. The
4386coordinates are returned in a 2-D tensor where the first dimension (rows)
4387represents the number of true elements, and the second dimension (columns)
4388represents the coordinates of the true elements. Keep in mind, the shape of
4389the output tensor can vary depending on how many true values there are in
4390`condition`. Indices are output in row-major order.
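
For illustration, a hypothetical use in MLIR assembly (for a 2x2 condition the
result has one row per true element and one column per input dimension; the
row count is not statically known):

```mlir
%0 = "tfl.where"(%condition) : (tensor<2x2xi1>) -> tensor<?x2xi64>
```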
4391  }];
4392
4393  let arguments = (ins
4394    TFL_BoolTensor:$input
4395  );
4396
4397  let results = (outs
4398    TFL_I64Tensor:$index
4399  );
4400}
4401
4402def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
4403    SameOperandsShape]> {
4404
4405  let summary = "Verifies the numericals of the two operands";
4406
4407  let description = [{
    The NumericVerify op is a debugging op that verifies the numerical values
    of the two activations. It is a custom op in TFLite.
    If log_if_failed is true, the NumericVerify op calculates statistics on
    the differences between the float and quantized activations, outputs logs,
    writes the differences to the output tensor, and throws an error if the
    errors exceed the tolerance. If log_if_failed = false, errors are ignored.
4415  }];
4416
4417  let arguments = (ins
4418    TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input,
4419    TFL_TensorOf<[F32]>:$ref,
4420
4421    // Attributes
4422    DefaultValuedAttr<F32Attr, "0.1">:$tolerance,
4423    DefaultValuedAttr<BoolAttr, "false">:$log_if_failed
4424  );
4425
4426  let results = (outs TFL_FpTensor:$output);
4427}
4428
4429// SVDF op.
4430def TFL_SVDFOp :
4431  TFL_Op<"svdf", [
4432    PredOpTrait<"the input and result tensor elemental types must be same",
4433      TFL_TCresVTEtIsSameAsOp<0, 0>>,
4434    TFL_StatefulOp,
4435    AccumulatorUniformScale<3, 2, 4>]> {
4436
4437  let summary = "Single value decomposition filter operator";
4438
4439  let description = [{
4440    The SVDF op is a decomposition of a densely connected op into low rank
4441    filters.
4442    For details: https://research.google.com/pubs/pub43813.html
4443                 https://arxiv.org/abs/1812.02802
4444  }];
4445
4446  let arguments = (
4447    ins TFL_TensorOf<[F32, QI8]>:$input,
4448
4449    // Feature Weights.
4450    TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights,
4451
4452    // Time weights
4453    TFL_TensorOf<[F32, QI16]>:$time_weights,
4454
4455    // Bias
4456    TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
4457
4458    // Activation state.
4459    TFL_StatefulTensor:$activation_state,
4460
4461    // Attributes
4462    Confined<I32Attr, [IntPositive]>:$rank,
4463    TFL_AFAttr:$fused_activation_function
4464  );
4465
4466  let results = (outs TFL_TensorOf<[F32, QI8]>:$output);
4467
4468  let hasOptions = 1;
4469
4470  let verifier = [{ return Verify(*this); }];
4471
4472  let extraClassDeclaration = [{
4473    // StatefulOpInterface:
4474    std::vector<int> GetStatefulOperands() { return {4}; }
4475  }];
4476}
4477
4478def TFL_SegmentSumOp: TFL_Op<"segment_sum", [
4479    NoSideEffect,
4480    PredOpTrait<"input and output must have same element type",
4481      TFL_TCresVTEtIsSameAsOp<0, 0>>,
4482    NoQuantizableResult]> {
4483  let summary = "SegmentSum operator";
4484
4485  let description = [{
4486    Computes the sum along segments of a tensor.
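
    For illustration, a hypothetical use in MLIR assembly (with
    segment_ids = [0, 0, 1], output[0] sums rows 0 and 1 and output[1] is
    row 2; shapes chosen only for the example):

    ```mlir
    %0 = "tfl.segment_sum"(%input, %segment_ids) : (tensor<3x4xf32>, tensor<3xi32>) -> tensor<2x4xf32>
    ```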
4487  }];
4488
4489  let arguments = (ins
4490    TFL_TensorOf<[F32, I32]>:$input,
4491    TFL_I32Tensor:$segment_ids
4492  );
4493  let results = (outs TFL_TensorOf<[F32, I32]>:$output);
4494}
4495
4496def TFL_YieldOp : Op<TFL_Dialect, "yield", [NoSideEffect, Terminator]> {
4497  let summary = "Yield operation";
4498  let description = [{
4499    The "yield" operation represents a return operation within the conditional
4500    and body of structured control flow (e.g., while). The operation takes
4501    variable number of operands and produces no results. The operand number and
4502    types must match the signature of the region that contains the operation.
4503  }];
4504
4505  let arguments = (ins Variadic<AnyType>:$operands);
4506}
4507
4508def TFL_IfOp : Op<TFL_Dialect, "if", [
4509    DeclareOpInterfaceMethods<RegionBranchOpInterface>,
4510    SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects,
4511    NoRegionArguments]> {
4512  let summary = [{if-then-else operation}];
4513
4514  let description = [{
4515    The `tfl.if` operation represents an if-then-else construct for
4516    conditionally executing two regions of code. The operand to an if operation
4517    is a boolean value. For example:
4518
4519    ```mlir
4520    tfl.if %b  {
4521      ...
4522    } else {
4523      ...
4524    }
4525    ```
4526
4527    `tfl.if` may also return results that are defined in its regions. The
4528    values defined are determined by which execution path is taken.
4529
4530    Example:
4531
4532    ```mlir
4533    %x, %y = tfl.if %b -> (tensor<f32>, tensor<f32>) {
4534      %x_true = ...
4535      %y_true = ...
4536      tfl.yield %x_true, %y_true : tensor<f32>, tensor<f32>
4537    } else {
4538      %x_false = ...
4539      %y_false = ...
4540      tfl.yield %x_false, %y_false : tensor<f32>, tensor<f32>
4541    }
4542    ```
4543
4544    `tfl.if` regions are always terminated with "tfl.yield". If "tfl.if"
4545    defines no values, the "tfl.yield" can be left out, and will be inserted
4546    implicitly. Otherwise, it must be explicit.
4547    Also, if "tfl.if" defines one or more values, the 'else' block cannot be
4548    omitted.
4549
4550    Example:
4551
4552    ```mlir
4553    tfl.if %b  {
4554      ...
4555    }
4556    ```
4557  }];
4558
4559  let arguments = (ins TFL_BoolTensor:$cond);
4560  let results = (outs Variadic<AnyTensor>:$results);
4561  let regions = (region SizedRegion<1>:$then_region, AnyRegion:$else_region);
4562
4563  let extraClassDeclaration = [{
4564    OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
4565      Block* body = getBody(0);
4566      return results().empty() ? OpBuilder::atBlockTerminator(body, listener)
4567                               : OpBuilder::atBlockEnd(body, listener);
4568    }
4569    OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
4570      Block* body = getBody(1);
4571      return results().empty() ? OpBuilder::atBlockTerminator(body, listener)
4572                               : OpBuilder::atBlockEnd(body, listener);
4573    }
4574  }];
4575
  // No canonicalizer is defined for this op. In practice, we legalize
  // tf.IfOp to the scf.If op first and then legalize that to tfl.if to
  // reduce code redundancy.
4579}
4580
4581def TFL_WhileOp : Op<TFL_Dialect, "while", [
4582    DeclareOpInterfaceMethods<LoopLikeOpInterface>,
4583    SingleBlockImplicitTerminator<"YieldOp">]> {
4584  let summary = [{While loop}];
4585
4586  let description = [{
4587    output = input; while (cond(output)) { output = body(output) }
4588
    A while loop where all values are passed through as arguments, with
    implicit capture.
4591
4592    input: A list of input tensors whose types are T.
4593    output: A list of output tensors whose types are T.
4594    cond: A region that takes 'input' and returns a boolean scalar tensor.
4595    body: A region that takes a list of tensors and returns another
4596          list of tensors. Both lists have the same types.
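
    For illustration, a hypothetical skeleton in MLIR assembly (the elided
    bodies must produce a boolean scalar tensor and the next iteration values,
    respectively):

    ```mlir
    %0 = "tfl.while"(%init) ({
    ^bb0(%arg: tensor<i32>):
      %cond = ...
      "tfl.yield"(%cond) : (tensor<i1>) -> ()
    }, {
    ^bb0(%arg: tensor<i32>):
      %next = ...
      "tfl.yield"(%next) : (tensor<i32>) -> ()
    }) : (tensor<i32>) -> tensor<i32>
    ```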
4597  }];
4598
4599  let arguments = (ins
4600    Variadic<AnyTensor>:$input,
4601
4602    // Used to map StatelessWhile and While op defined in TensorFlow to a common
4603    // op.
4604    DefaultValuedAttr<BoolAttr, "false">:$is_stateless
4605  );
4606  let results = (outs Variadic<AnyTensor>:$output);
4607
4608  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
4609
4610  let verifier = [{ return Verify(*this); }];
4611
4612  let hasCanonicalizer = 1;
4613}
4614
4615def TFL_CallOnceOp : TFL_Op<"call_once", []> {
4616  let summary = "Invokes an initialization function";
4617
4618  let description = [{
4619This operation invokes the given initialization function for the session
4620initializer in tf saved model dialect.
4621  }];
4622
4623  let arguments = (ins
4624    StrAttr:$session_init_function
4625  );
4626
4627  let results = (outs);
4628}
4629
4630def TFL_CustomOp : Op<TFL_Dialect, "custom", [
4631  NoSideEffect, NoQuantizableResult]> {
4632  let summary = "Custom op";
4633
4634  let description = [{
4635    A generic op for any TFLite custom operation.
4636
4637    input: A list of inputs in the original op.
    custom_code: A string identifying exactly which op this is; it corresponds
                 to operator_codes.custom_code in the flatbuffer.
    custom_option: a holder for the op attributes, serialized as bytes.
4641    output: A list of outputs in the original op.
4642  }];
4643
4644  let arguments = (ins
4645    Variadic<TFL_TensorOfOrNone<[AnyType]>>:$input,
4646    StrAttr:$custom_code,
4647    OpaqueBytesAttr:$custom_option
4648  );
4649  let results = (outs Variadic<AnyTensor>:$output);
4650
4651  let verifier = [{ return Verify(*this); }];
4652}
4653
4654def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
4655  RecursiveSideEffects,
4656  NoQuantizableResult,
4657  IsolatedFromAbove,
4658  SingleBlockImplicitTerminator<"YieldOp">,
4659  DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
4660  let summary = "Wrapper Op for TF custom ops.";
4661
4662  let description = [{
    A wrapper op around any custom TF op. This includes ops defined via
    custom_opdefs or linked in, that are not defined in the TF dialect.
    This op just wraps the custom op inside a region.
    Note #1: this op does not cover TF Lite custom ops defined via CustomOp.
    Note #2: this op is only an internal representation inside the converter
    and is not exposed/exported when the model is exported to Flatbuffer.
4669  }];
4670
4671  let arguments = (ins
4672    Variadic<TFL_TensorOfOrNone<[AnyType]>>:$input
4673  );
4674  let results = (outs Variadic<AnyTensor>:$output);
4675
4676  let regions = (region SizedRegion<1>:$body);
4677
4678  let extraClassDeclaration = [{
4679    // Returns whether the return types are compatible.
4680    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
4681  }];
4682}
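
// As a sketch, the wrapper keeps the original TF op inside its single-block
// region; the wrapped op below is hypothetical:
//
//   %0 = "tfl.custom_tf"(%arg0) ({
//   ^bb0(%inner: tensor<4xf32>):
//     %1 = "tf.MyCustomKernel"(%inner) : (tensor<4xf32>) -> tensor<4xf32>
//     "tfl.yield"(%1) : (tensor<4xf32>) -> ()
//   }) : (tensor<4xf32>) -> tensor<4xf32>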

def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankAtMost<0, 8>,
    TFL_OperandHasRank<1, 1>,
    PredOpTrait<"output dimension count must be at most 8",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 8>]>>,
    NoSideEffect]> {
  let summary = "Broadcast an array for a compatible shape.";

  let description = [{
Broadcasting is the process of making arrays have compatible shapes for
arithmetic operations. Two shapes are compatible if, for each dimension pair,
they are either equal or one of them is one. When trying to broadcast a Tensor
to a shape, it starts with the trailing dimensions and works its way forward.

For example,

>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
    [[1 2 3]
     [1 2 3]
     [1 2 3]], shape=(3, 3), dtype=int32)

In the above example, the input Tensor with shape `[3]` is broadcast to an
output Tensor with shape `[3, 3]`.

When doing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or space
benefit, as the broadcasted tensor is never materialized.

However, `broadcast_to` does not carry with it any such benefits.
The newly-created tensor takes the full memory of the broadcasted
shape. (In a graph context, `broadcast_to` might be fused into the
subsequent operation and then optimized away, however.)
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
    TFL_I32OrI64Tensor:$shape
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
  );
}
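
// The same example at the TFLite dialect level, assuming a constant shape
// operand (a sketch, not actual converter output):
//
//   %shape = "tfl.pseudo_const"() {value = dense<[3, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
//   %y = "tfl.broadcast_to"(%x, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>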

def TFL_RFFT2dOp : TFL_Op<"rfft2d", [NoSideEffect, NoQuantizableResult]> {
  let summary = "2D real-valued fast Fourier transform.";

  let description = [{
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 2 dimensions of `input`.

Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.

Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
  }];

  let arguments = (ins
    TFL_FpTensor:$input,
    TFL_I32Tensor:$fft_length
  );

  let results = (outs
    TFL_Complex64Tensor:$output
  );
}
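
// For instance, with fft_length = [16, 16] the inner-most result dimension is
// 16 / 2 + 1 = 9. A sketch with illustrative shapes:
//
//   %out = "tfl.rfft2d"(%in, %len)
//       : (tensor<1x16x16xf32>, tensor<2xi32>) -> tensor<1x16x9xcomplex<f32>>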

def TFL_VarHandleOp : TFL_Op<"var_handle", []> {
  let summary = "Returns a handle to a variable resource from its name.";

  let description = [{
    Returns a handle for a variable resource from its name.
    container: the container this variable is placed in.
    shared_name: the name by which this variable is referred to.
  }];

  let arguments = (ins
    DefaultValuedAttr<StrAttr, "">:$container,
    DefaultValuedAttr<StrAttr, "">:$shared_name
  );

  let results = (outs TFL_ResourceTensor:$resource_handle);

  let hasOptions = 1;
}

def TFL_AssignVariableOp : TFL_Op<"assign_variable", []> {
  let summary = "Assigns a new value to a variable.";

  let description = [{
Any ReadVariableOp with a control dependency on this op is guaranteed to return
this value or a subsequent newer value of the variable.
  }];

  let arguments = (ins
    TFL_ResourceTensor:$resource_id,
    TFL_TensorOf<[F32, F64, I1, UI8, I8, QI8, QUI8, I32, I64, QI16, Complex<F<32>>, Complex<F<64>>]>:$value
  );

  let results = (outs);
}

def TFL_ReadVariableOp : TFL_Op<"read_variable", []> {
  let summary = "Reads variable value.";

  let description = [{
Read variable data identified by 'resource_id'.
  }];

  let arguments = (ins
    TFL_ResourceTensor:$resource_id
  );

  let results = (outs TFL_TensorOf<[F32, F64, I1, UI8, I8, QI8, QUI8, I32, I64, QI16, Complex<F<32>>, Complex<F<64>>]>:$result);
}
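
// A sketch of the three variable ops used together; the value and name are
// illustrative, and the printed form of the resource tensor type is
// abbreviated here as !tf_type.resource (the exact spelling may differ):
//
//   %handle = "tfl.var_handle"() {container = "", shared_name = "v"} : () -> tensor<!tf_type.resource>
//   "tfl.assign_variable"(%handle, %value) : (tensor<!tf_type.resource>, tensor<4xf32>) -> ()
//   %read = "tfl.read_variable"(%handle) : (tensor<!tf_type.resource>) -> tensor<4xf32>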

def TFL_Conv3DOp : TFL_Op<"conv_3d", [
    NoSideEffect,
    NoQuantizableResult,
    AccumulatorUniformScale<2, 0, 1>,
    TFL_OperandHasRank<0, 5>,
    TFL_OperandHasRank<1, 5>,
    // Channel dimension in input and filter should match.
    TFL_OperandsHaveSameDimsTrait<0, 1, 4, 3>,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    PredOpTrait<"bias and output must have same element type",
      Or<[
        TFL_OperandIsNoneType<2>,
        TFL_TCresVTEtIsSameAsOp<0, 2>]>>,
    PredOpTrait<"bias must have num of elements equal to 4th dim of filter",
      Or<[
        TFL_OperandIsNoneType<2>,
        TFL_NumElementsEqualsDim<2, 1, 4>]>>]> {
  let summary = "Convolution 3D operator";

  let description = [{
    Performs a convolution operation on 3D inputs.
    Inputs:
      `inputs[0]`: required: the input activation tensor
      `inputs[1]`: required: the filter weight tensor
      `inputs[2]`: optional: the bias tensor
  }];

  let arguments = (ins
    TFL_TensorOf<[F32]>:$input,
    TFL_TensorOf<[F32]>:$filter,
    TFL_TensorOfOrNone<[F32]>:$bias,
    I32Attr:$dilation_d_factor,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_d,
    I32Attr:$stride_h,
    I32Attr:$stride_w
  );

  let results = (outs TFL_TensorOf<[F32]>:$output);

  let hasOptions = 1;

  let customOption = "Conv3DOptions";
}
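
// A sketch of a conv_3d invocation; the shapes follow the rank-5
// [batch, depth, height, width, channels] layout implied by the traits above
// (input channels match filter dim 3, bias size matches filter dim 4):
//
//   %out = "tfl.conv_3d"(%input, %filter, %bias) {
//     dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32,
//     dilation_w_factor = 1 : i32, fused_activation_function = "NONE",
//     padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32,
//     stride_w = 1 : i32
//   } : (tensor<1x8x8x8x4xf32>, tensor<3x3x3x4x16xf32>, tensor<16xf32>)
//       -> tensor<1x8x8x8x16xf32>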

def TFL_Conv3DTransposeOp : TFL_Op<"conv_3d_transpose", [
    NoSideEffect,
    NoQuantizableResult,
    AccumulatorUniformScale<2, 0, 1>,
    TFL_OperandHasRank<0, 1>,
    TFL_OperandHasRank<1, 5>,
    TFL_OperandHasRank<2, 5>,
    TFL_NumElementsTrait<0, 5>,
    // Channel dimension in input and filter should match.
    TFL_OperandsHaveSameDimsTrait<2, 1, 4, 4>,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 2>>,
    PredOpTrait<"bias and output must have same element type",
      Or<[
        TFL_OperandIsNoneType<3>,
        TFL_TCresVTEtIsSameAsOp<0, 3>]>>,
    PredOpTrait<"bias must have num of elements equal to 4th dim of filter",
      Or<[
        TFL_OperandIsNoneType<3>,
        TFL_NumElementsEqualsDim<3, 1, 4>]>>]> {
  let summary = "Transposed Convolution 3D operator";

  let description = [{
    Performs a transposed convolution operation on 3D inputs.
    Inputs:
      `inputs[0]`: required: the shape of output tensor
      `inputs[1]`: required: the filter weight tensor
      `inputs[2]`: required: the input activation tensor
      `inputs[3]`: optional: the bias tensor
  }];

  let arguments = (ins
    TFL_I32Tensor:$output_shape,
    TFL_TensorOf<[F32]>:$filter,
    TFL_TensorOf<[F32]>:$input,
    TFL_TensorOfOrNone<[F32]>:$bias,
    I32Attr:$dilation_d_factor,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_d,
    I32Attr:$stride_h,
    I32Attr:$stride_w
  );

  let results = (outs TFL_TensorOf<[F32]>:$output);

  let hasOptions = 1;

  let customOption = "Conv3DOptions";
}
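
// A sketch; note the operand order (output shape first, then filter, input,
// and optional bias). The attribute list is elided and the bias is omitted
// via a none value:
//
//   %shape = "tfl.pseudo_const"() {value = dense<[1, 8, 8, 8, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
//   %out = "tfl.conv_3d_transpose"(%shape, %filter, %input, %none) { ... }
//       : (tensor<5xi32>, tensor<3x3x3x4x16xf32>, tensor<1x8x8x8x16xf32>, none)
//       -> tensor<1x8x8x8x4xf32>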

def TFL_ComplexAbsOp : TFL_Op<"complex_abs", [
  NoSideEffect,
  SameOperandsAndResultShape]> {
  let summary = "Computes the complex absolute value of a tensor.";

  let description = [{
Given a tensor `x` of complex numbers, this operation returns a tensor of type
`float` or `double` that is the absolute value of each element in `x`. All
elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute
value is computed as \\( \sqrt{a^2 + b^2}\\).
  }];

  let arguments = (ins
    TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
  );

  let results = (outs
    TFL_TensorOf<[F32, F64]>:$output
  );
}
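
// For example, |3 + 4j| = sqrt(3^2 + 4^2) = 5. As an IR sketch with an
// illustrative shape:
//
//   %abs = "tfl.complex_abs"(%z) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>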

def TFL_RealOp : TFL_Op<"real", [
  NoSideEffect,
  SameOperandsAndResultShape]> {
  let summary = "Returns the real part of a complex number.";

  let description = [{
Given a tensor `input` of complex numbers, this operation returns a tensor of
type `float` that is the real part of each element in `input`. All elements in
`input` must be complex numbers of the form \\(a + bj\\), where *a* is the
real part returned by this operation and *b* is the imaginary part.
  }];

  let arguments = (ins
    TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
  );

  let results = (outs
    TFL_TensorOf<[F32, F64]>:$output
  );
}

def TFL_ImagOp : TFL_Op<"imag", [
  NoSideEffect,
  SameOperandsAndResultShape]> {
  let summary = "Returns the imaginary part of a complex number.";

  let description = [{
Given a tensor `input` of complex numbers, this operation returns a tensor of
type `float` that is the imaginary part of each element in `input`. All
elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
is the real part and *b* is the imaginary part returned by this operation.
  }];

  let arguments = (ins
    TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
  );

  let results = (outs
    TFL_TensorOf<[F32, F64]>:$output
  );
}
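
// Given %z = [-2.25 + 4.75j, 3.25 + 5.75j], a sketch of both ops:
//
//   %re = "tfl.real"(%z) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>  // [-2.25, 3.25]
//   %im = "tfl.imag"(%z) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>  // [4.75, 5.75]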

def TFL_HashtableOp: TFL_Op<"hashtable", []> {
  let summary = "Creates an uninitialized hash table.";
  let description = [{
This op creates a hash table, specifying the type of its keys and values.
Before using the table you will have to initialize it. After initialization
the table will be immutable.
  }];

  let arguments = (ins
    I32Attr:$table_id,
    TypeAttr:$key_dtype,
    TypeAttr:$value_dtype
  );

  let results = (outs TFL_ResourceTensor:$out);

  let hasOptions = 1;
}

def TFL_HashtableFindOp: TFL_Op<"hashtable_find", []> {
  let summary = "Looks up keys in a table, outputs the corresponding values.";

  let description = [{
The tensor `keys` must be of the same type as the keys of the table.
The output `values` is of the type of the table values.

The scalar `default_value` is the value output for keys not present in the
table. It must also be of the same type as the table values.
  }];

  let arguments = (ins
    TFL_ResourceTensor:$hash_table,
    TFL_TensorOf<[I32, TFL_Str, I64]>:$keys,
    TFL_TensorOf<[F32, I32, TFL_Str, I64]>:$default_value
  );

  let results = (outs TFL_TensorOf<[F32, I32, TFL_Str, I64]>:$out);
}

def TFL_HashtableImportOp: TFL_Op<"hashtable_import", []> {
  let summary = [{
Replaces the contents of the table with the specified keys and values.
  }];

  let description = [{
The tensor `keys` must be of the same type as the keys of the table.
The tensor `values` must be of the type of the table values.
  }];

  let arguments = (ins
    TFL_ResourceTensor:$hash_table,
    TFL_TensorOf<[I32, TFL_Str, I64]>:$keys,
    TFL_TensorOf<[F32, I32, TFL_Str, I64]>:$values
  );

  let results = (outs);
}

def TFL_HashtableSizeOp: TFL_Op<"hashtable_size", []> {
  let summary = "Computes the number of elements in the given table.";

  let arguments = (ins
    TFL_ResourceTensor:$hash_table
  );

  let results = (outs
    TFL_I64Tensor:$out
  );
}
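
// A sketch of the hashtable ops used together; the table_id and types are
// illustrative, operand type lists are abbreviated with "...", and the
// printed resource/string type spellings may differ:
//
//   %t = "tfl.hashtable"() {table_id = 1 : i32, key_dtype = !tf_type.string, value_dtype = i64} : () -> tensor<!tf_type.resource>
//   "tfl.hashtable_import"(%t, %keys, %values) : (...) -> ()
//   %v = "tfl.hashtable_find"(%t, %query, %default) : (...) -> tensor<1xi64>
//   %n = "tfl.hashtable_size"(%t) : (tensor<!tf_type.resource>) -> tensor<i64>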

def TFL_BroadcastArgsOp : TFL_Op<"broadcast_args", [
    OperandsSameElementTypeConstraintBase<"BroadcastArgs op">,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRank<0, 1>,
    TFL_OperandHasRank<1, 1>,
    NoSideEffect]> {
  let summary = "Return the shape of s0 op s1 with broadcast.";

  let description = [{
Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the
broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
  }];

  let arguments = (ins
    TFL_I32OrI64Tensor:$s0,
    TFL_I32OrI64Tensor:$s1
  );

  let results = (outs
    TFL_I32OrI64Tensor:$r0
  );
}
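
// For example, broadcasting the shapes s0 = [2, 1, 3] and s1 = [4, 3] gives
// r0 = [2, 4, 3]:
//
//   %r0 = "tfl.broadcast_args"(%s0, %s1) : (tensor<3xi64>, tensor<2xi64>) -> tensor<3xi64>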

#endif // TFL_OPS

// LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
