/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Copyright 2022 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_DIALECT_BASE
#define STABLEHLO_DIALECT_BASE

include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// HLO type definitions.
//===----------------------------------------------------------------------===//

// Boolean element type: an alias of the builtin 1-bit integer type I1.
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;

// Integer element types, available at 4/8/16/32/64 bits in a signless and an
// unsigned flavor; HLO_Int accepts either flavor.
// TODO(hinsu): Use signed integers instead of signless integer which is being
// used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

// Floating-point element types: every supported float (F16, F32, F64, BF16),
// and the narrower F32/F64-only subset.
def HLO_Float : AnyTypeOf<[F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

// Complex element type whose component type is F32 or F64. Built on
// HLO_Float32Or64 (same predicate and summary as an inline
// AnyTypeOf<[F32, F64]>) so the two definitions cannot drift apart.
def HLO_Complex : Complex<HLO_Float32Or64>;

//===----------------------------------------------------------------------===//
// Quantized element type definitions.
//===----------------------------------------------------------------------===//

// TODO(b/230381284): Upstream width-specific uniform quantized element types.
// Predicate that holds when `$_self` is an instance of `quantType` (a fully
// qualified C++ uniform quantized type class) whose storage type is `width`
// bits wide, and whose storage is signed when `isSigned` is 1 or unsigned when
// it is 0.
class UniformQuantizedStorageIntPred<string quantType, int width, bit isSigned>
    : And<[
        // The type check comes first; the casts below rely on it holding.
        CPred<"$_self.isa<" # quantType # ">()">,
        CPred<"$_self.cast<" # quantType # ">()" #
              ".getStorageTypeIntegralWidth() == " # width>,
        CPred<!if(isSigned, "", "!") # "$_self.cast<" # quantType # ">()" #
              ".isSigned()">]>;

// A `width`-bit signed uniform quantized element type, either per-tensor
// (UniformQuantizedType) or per-axis (UniformQuantizedPerAxisType). The `name`
// and `bitwidth` fields record the class name and storage width.
class UniformQuantizedSignedInt<int width>
    : Type<Or<[
          UniformQuantizedStorageIntPred<
              "mlir::quant::UniformQuantizedType", width, 1>,
          UniformQuantizedStorageIntPred<
              "mlir::quant::UniformQuantizedPerAxisType", width, 1>]>,
      "QI" # width # " type"> {
  string name = "UniformQuantizedSignedInt";
  int bitwidth = width;
}

// Unsigned counterpart of UniformQuantizedSignedInt.
class UniformQuantizedUnsignedInt<int width>
    : Type<Or<[
          UniformQuantizedStorageIntPred<
              "mlir::quant::UniformQuantizedType", width, 0>,
          UniformQuantizedStorageIntPred<
              "mlir::quant::UniformQuantizedPerAxisType", width, 0>]>,
      "QUI" # width # " type"> {
  string name = "UniformQuantizedUnsignedInt";
  int bitwidth = width;
}

// Any signed uniform quantized type whose storage width is listed in `widths`.
// For widths [4, 8] the summary reads "4/8-bit uniform quantized signed
// integer".
class UniformQuantizedSignedIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(width, widths, UniformQuantizedSignedInt<width>),
                !interleave(widths, "/") #
                    "-bit uniform quantized signed integer">;

// Any unsigned uniform quantized type whose storage width is listed in
// `widths`.
class UniformQuantizedUnsignedIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(width, widths, UniformQuantizedUnsignedInt<width>),
                !interleave(widths, "/") #
                    "-bit uniform quantized unsigned integer">;

// Integer-based uniform quantized element types. These definitions can be used
// as the element types of operand tensor types.
def HLO_QuantizedSignedInt : UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedUnsignedInt : UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedInt :
    AnyTypeOf<[HLO_QuantizedSignedInt, HLO_QuantizedUnsignedInt]>;

// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
defvar BroadcastDimAttr = I64ElementsAttr;

// The dialect's token type (checked in C++ via `$_self.isa<TokenType>()`).
def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;

// Tensors of any HLO integer element type.
def HLO_IntTensor : TensorOf<[HLO_Int]>;

// Rank-0 tensors of an HLO integer element type, i.e. a single integer.
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;

// Tensors of any HLO floating-point element type.
def HLO_FpTensor : TensorOf<[HLO_Float]>;

// Tensors restricted to F32 or F64 elements.
def HLO_Fp32Or64Tensor : TensorOf<[HLO_Float32Or64]>;

// Tensors of any uniform quantized integer element type.
def HLO_QuantizedIntTensor : TensorOf<[HLO_QuantizedInt]>;

// Tensors of boolean (pred) elements.
def HLO_PredTensor : TensorOf<[HLO_Pred]>;

// Tensors of any HLO element type: float, pred, int, complex or quantized int.
def HLO_Tensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

// Tensors of complex elements.
def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

// Tuples (possibly nested) whose leaves are HLO tensors or tokens.
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

// Either an HLO tensor or a token.
def HLO_TensorOrToken : AnyTypeOf<[HLO_Tensor, HLO_Token]>;

// An HLO tensor, a token, or a tuple of those.
def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;

// A single dimension size: an index or any HLO integer.
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Int]>;

// Dynamic representation of a shape vector as a 1-D tensor of dimension
// values.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;

// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
      HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

//===----------------------------------------------------------------------===//
// HLO combined type definitions.
//===----------------------------------------------------------------------===//

// Tensors whose elements are integers or floats.
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, HLO_Float]>;

// Tensors whose elements are booleans (pred) or integers.
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;

// Tensors whose elements are floats or complex numbers.
def HLO_FpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>;

// Tensors whose elements are integers, floats or complex numbers.
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex]>;

// Tensors whose elements are booleans (pred), integers or floats.
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

//===----------------------------------------------------------------------===//
// HLO traits
//===----------------------------------------------------------------------===//

// A NativeOpTrait whose C++ class `name` lives in the ::mlir::hlo::OpTrait
// namespace.
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
  let cppNamespace = "::mlir::hlo::OpTrait";
}

// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;

// Op has pairwise operand and result type matching: the number of operands
// must be equal to the number of results and the type of ith operand must
// match the type of ith result.
// TODO(b/195086460) Promote this to be an mlir trait and remove it here.
def HLO_PairwiseSameOperandAndResultType :
    HLO_NativeOpTrait<"PairwiseSameOperandAndResultType">;

// Op has operand and result types compatible with each other according to
// the rules implemented in isCompatibleForHloTypeInference, which account for
// special properties such as dynamism, quantization and sparsity.
// The trait list bundles InferTypeOpInterface, InferShapedTypeOpInterface
// (declaring an op-specific inferReturnTypeComponents) and the native
// CompatibleOperandsAndResultType trait.
def HLO_CompatibleOperandsAndResultType : TraitList<
  // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
  // support quantization or sparsity.
  [
    InferTypeOpInterface,
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>,
    HLO_NativeOpTrait<"CompatibleOperandsAndResultType">
  ]>;

// Attribute interface in ::mlir::hlo with a single method, getBounds(), that
// returns the bounds as ::llvm::ArrayRef<int64_t>.
def HLO_BoundedAttrInterface : AttrInterface<"BoundedAttrInterface"> {
  let cppNamespace = "::mlir::hlo";

  let description = [{
    This interface is used for attributes that carry bounds for dimension sizes
    of an accompanying shaped type, e.g. when the attribute represents a
    RankedTensorType::getEncoding.
    The number of bounds is expected to be the same as the number of dimensions
    in the accompanying shaped type.
    For a static dimension, the corresponding bound is ShapedType::kDynamicSize.
    For a dynamic dimension, the corresponding bound is either known and is
    a non-negative number or unknown and is ShapedType::kDynamicSize.
  }];

  let methods = [InterfaceMethod<
    "Get the attribute's bounds",
    "::llvm::ArrayRef<int64_t>", "getBounds"
  >];
}

#endif // STABLEHLO_DIALECT_BASE