/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/IR/PatternBase.td"

// Matches an attribute whose C++ class is `DenseElementsAttr`, i.e. a constant
// tensor whose contents are stored densely (not opaque).
def DenseElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<DenseElementsAttr>()">,
  "non-opaque constant tensor">;

// Checks if the data format is "NHWC".
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "\"NHWC\"">;

// Checks if the value is produced by a `TF::ConstOp`. A value without a
// defining op (e.g. a block argument) does not match, because
// `dyn_cast_or_null` tolerates the null result of `getDefiningOp()`.
def IsConstTensor : Constraint<CPred<"dyn_cast_or_null<TF::ConstOp>($0.getDefiningOp())">>;

// Checks if the attribute is an ElementsAttr whose element type is a float
// type.
def IsFloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && "
        "getElementTypeOrSelf($_self.cast<ElementsAttr>().getType()).isa<FloatType>()">,
  "float constant tensor">;

// Checks if the attribute is a BoolAttr holding `false`.
// The attribute is first checked with `dyn_cast_or_null`, so a null attribute
// or an attribute of another kind simply fails to match instead of triggering
// the assertion inside `cast<BoolAttr>()`.
def IsFalseBoolAttr : AttrConstraint<
  CPred<"$_self.dyn_cast_or_null<BoolAttr>() && "
        "!$_self.cast<BoolAttr>().getValue()">>;

// Checks if the value has exactly one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

// Gets the type of a value.
def GetValueType : NativeCodeCall<"$0.getType()">;

// Checks if the value's element type (or the value's type itself, when it is
// not a shaped type) is an 8-bit integer. `isInteger(8)` only checks the bit
// width, not the signedness.
def IsInt8ElementType : Constraint<
  CPred<"getElementTypeOrSelf($0).isInteger(8)">>;

// Checks if the value's element type (or the value's type itself, when it is
// not a shaped type) is a 32-bit integer of any signedness.
def IsInt32ElementType : Constraint<
  CPred<"getElementTypeOrSelf($0).isInteger(32)">>;

// Checks if the value's element type (or the value's type itself, when it is
// not a shaped type) is float32.
def IsF32ElementType : Constraint<
  CPred<"getElementTypeOrSelf($0).isF32()">>;

// Checks if the value has a static shape. The check itself is delegated to the
// C++ helper `HasStaticShape`, which is defined outside this file.
def HasStaticShapeConstraint : Constraint<CPred<"HasStaticShape($0)">>;

// Checks if the value has a static shape at the given dims.
// `dims` is spliced verbatim into a C++ brace-initializer list passed to the
// helper `HasStaticShapeAtDims`, so it must be a comma-separated list of C++
// expressions (e.g. "1, 2").
class HasStaticShapeAtDimsConstraint<string dims> : Constraint<
  CPred<"HasStaticShapeAtDims($0, {"# dims #"})">>;

// A rewrite rule cannot replace a value with itself. As a workaround, this
// clones the root op to replicate that value; the old op will get folded.
// `op0` is the name the generated rewriter gives to the matched root op.
def CloningOpResult : NativeCodeCall<
  "$_builder.clone(*op0)->getOpResult(0)">;

// Same as CloningOpResult, but for ops with multiple results.
// `returns` is the number of results this NativeCodeCall produces.
class CloningOpResults<int returns> : NativeCodeCall<
  "$_builder.clone(*op0)->getOpResults()", returns>;

// Creates a 1D array const with float values.
// `values` is spliced verbatim as the last argument of the C++ helper
// `Create1DConstValue<float>`.
class Create1DConst<string values> : NativeCodeCall<
  "Create1DConstValue<float>($_builder, $_loc, "# values #")">;

// Creates a scalar const with a float value.
// `value` is spliced verbatim as the last argument of the C++ helper
// `CreateScalarConstValue<float>`.
class CreateScalarConst<string value> : NativeCodeCall<
  "CreateScalarConstValue<float>($_builder, $_loc, "# value #")">;

// Creates a 1D array const with integer values.
// `type` is used as the C++ template argument of `Create1DConstValue`, and
// `values` is spliced verbatim as its last argument.
// TODO(b/239490133): Make the rule name and function name consistent.
class Create1DIntegerConst<string type, string values> : NativeCodeCall<
  "Create1DConstValue<"# type #">($_builder, $_loc, "# values #")">;

// Creates a scalar const with an integer value.
// `type` is used as the C++ template argument of `CreateScalarConstValue`, and
// `value` is spliced verbatim as its last argument.
class CreateScalarIntegerConst<string type, string value> : NativeCodeCall<
  "CreateScalarConstValue<"# type #">($_builder, $_loc, "# value #")">;

// Creates an I64 array attribute with the given values.
// `values` is passed verbatim to `$_builder.getI64ArrayAttr`.
class CreateI64ArrayAttr<string values> : NativeCodeCall<
  "$_builder.getI64ArrayAttr("# values #")">;

// Creates a string attribute from the given string; despite its name, the
// `values` parameter is a single C++ expression passed to
// `$_builder.getStringAttr`.
class CreateStringAttr<string values> : NativeCodeCall<
  "$_builder.getStringAttr("# values #")">;

// Creates a new type with the same shape as the given value's type but with an
// F32 element type, via the C++ helper `CloneTypeWithNewElementType` (defined
// outside this file).
def CloneTypeWithF32ElementType : NativeCodeCall<
  "CloneTypeWithNewElementType($0.getType(), $_builder.getF32Type())">;

// By default, the generated code uses the `create` method without the output
// type field. However, for many ops, the output type field is always required.
// `$0...` forwards every argument of this call, in order, right after `$_loc`.
// NOTE(review): the output type is presumably passed as the first argument by
// the patterns using this class; confirm against callers.
class CreateOpWithOutputType<string op_name> : NativeCodeCall<
  "$_builder.create<"# op_name #">($_loc, $0...)">;

// Checks if the value is a float constant and its splat value is equal to `x`.
// `x` is spliced verbatim as a C++ expression into `IsSplatValueEqual<float>`.
class IsSplatValueEqual<string x> : Constraint<CPred<
  "IsSplatValueEqual<float>($0, "# x #")">>;

// Checks if two values are float constants and their values are equal.
def AreSplatValuesEqual : Constraint<CPred<
  "AreSplatValuesEqual<float>($0, $1)">>;

// Checks if the value is an integer constant and its splat value is equal to
// `x`. `type` is the C++ element type used as the template argument, and `x`
// is spliced verbatim as a C++ expression.
class IsIntSplatValueEqual<string type, string x> : Constraint<CPred<
  "IsSplatValueEqual<"# type #">($0, "# x #")">>;

// Checks if two values are integer constants and their values are equal.
// `type` is the C++ element type used as the template argument.
class AreIntSplatValuesEqual<string type> : Constraint<CPred<
  "AreSplatValuesEqual<"# type #">($0, $1)">>;