1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_QUANT_QUANT_TYPES_H_ 10 #define MLIR_DIALECT_QUANT_QUANT_TYPES_H_ 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/Dialect.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/Types.h" 18 #include "llvm/Support/MathExtras.h" 19 20 namespace mlir { 21 namespace quant { 22 23 class QuantizedIntegerType; 24 25 namespace detail { 26 27 struct QuantizedTypeStorage; 28 struct AnyQuantizedTypeStorage; 29 struct UniformQuantizedTypeStorage; 30 struct UniformQuantizedPerAxisTypeStorage; 31 struct CalibratedQuantizedTypeStorage; 32 33 } // namespace detail 34 35 /// Enumeration of bit-mapped flags related to quantized types. 36 namespace QuantizationFlags { 37 enum FlagValue { 38 // Indicates that the storage type should be interpreted as a signed 39 // integer. The default is to interpret it as an unsigned value. 40 Signed = 1, 41 }; 42 } // namespace QuantizationFlags 43 44 /// Base class for all quantized types known to this dialect. 45 /// All quantized types have: 46 /// - storageType: The (narrower) numeric type that is being used to 47 /// approximate some expressed type. 48 /// - expressedType: The type that is being approximated. 49 /// 50 /// The base class provides generic support for manipulating the types based 51 /// on these fields. 52 class QuantizedType : public Type { 53 public: 54 using ImplType = detail::QuantizedTypeStorage; 55 using Type::Type; 56 57 /// The maximum number of bits supported for storage types. 58 static constexpr unsigned MaxStorageBits = 32; 59 60 static LogicalResult 61 verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, 62 Type expressedType, int64_t storageTypeMin, 63 int64_t storageTypeMax); 64 65 /// Support method to enable LLVM-style type casting. 66 static bool classof(Type type); 67 68 /// Gets the minimum possible stored by a storageType. storageTypeMin must 69 /// be greater than or equal to this value. getDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)70 static int64_t getDefaultMinimumForInteger(bool isSigned, 71 unsigned integralWidth) { 72 if (isSigned) { 73 return llvm::minIntN(integralWidth); 74 } 75 return 0; 76 } 77 78 /// Gets the maximum possible stored by a storageType. storageTypeMax must 79 /// be less than or equal to this value. getDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)80 static int64_t getDefaultMaximumForInteger(bool isSigned, 81 unsigned integralWidth) { 82 if (isSigned) { 83 return llvm::maxIntN(integralWidth); 84 } 85 return llvm::maxUIntN(integralWidth); 86 } 87 88 /// Gets the original expressed type that this quantized type approximates. 89 /// Note that this presumes that the quantized type was always derived from 90 /// a floating point type, which in the broadest definition, is not true (i.e. 91 /// it could be some form of integral, fixed type or affine type in its own 92 /// right); however, at the high level, no examples of such usage are 93 /// presently known and the restriction serves some useful purposes (such as 94 /// always being able to reverse a transformation or measure error). In most 95 /// cases, this will be f32. 96 Type getExpressedType() const; 97 98 /// Gets the flags associated with this type. Typically a more specific 99 /// accessor is appropriate. 100 unsigned getFlags() const; 101 102 // Convenience helpers. 103 /// Whether the storage type should be interpreted as a signed quantity 104 /// (true) or an unsigned value (false). isSigned()105 bool isSigned() const { 106 return (getFlags() & QuantizationFlags::Signed) == 107 QuantizationFlags::Signed; 108 } 109 110 /// Gets the underlying type used for to store values. Note that this may 111 /// be signed or unsigned. Use the isSigned() accessor to differentiate. 112 Type getStorageType() const; 113 114 /// The minimum value that storageType can take. 115 int64_t getStorageTypeMin() const; 116 117 /// The maximum value that storageType can take. 118 int64_t getStorageTypeMax() const; 119 120 /// Gets the integral bit width that the underlying storage type can exactly 121 /// represent. For integral storage types, this will just be their width. 122 unsigned getStorageTypeIntegralWidth() const; 123 124 /// Returns whether the candidateExpressedType is a match for this 125 /// QuantizedType. This will be true if the candidate type is either a 126 /// primitive type or a container type whose element type equals this 127 /// QuantizedType's expressed type. 128 /// Examples of compatible candidateExpressedType: 129 /// !quant.uniform<i8:f32, 1.0> =~ f32 130 /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32> 131 bool isCompatibleExpressedType(Type candidateExpressedType); 132 133 /// Returns the element type as a QuantizedType or nullptr if it is not 134 /// a quantized type. If the type is primitive, returns that. If it is a 135 /// container (vector/tensor), return the element type. 136 /// Examples: 137 /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0> 138 /// tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0> 139 static QuantizedType getQuantizedElementType(Type primitiveOrContainerType); 140 141 /// Casts from a type based on the storageType to a corresponding type based 142 /// on this type (returns nullptr if the cast is not valid). 143 /// Examples: 144 /// i8 -> !quant.uniform<i8:f32, 1.0> 145 /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>> 146 /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>> 147 Type castFromStorageType(Type candidateType); 148 149 /// Casts from a type based on a QuantizedType to a corresponding type based 150 /// on the storageType (returns nullptr if the cast is not valid). 151 /// This is the inverse of castFromStorageType(). 152 static Type castToStorageType(Type quantizedType); 153 154 /// Casts from a type based on the expressedType to a corresponding type based 155 /// on this type (returns nullptr if the cast is not valid). 156 /// Examples: 157 /// f32 -> !quant.uniform<i8:f32, 1.0> 158 /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>> 159 /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>> 160 Type castFromExpressedType(Type candidateType); 161 162 /// Casts from a type based on QuantizedType to a corresponding type based 163 /// on the expressedType (returns nullptr if the cast is not valid). 164 /// This is the inverse of castFromExpressedType. 165 static Type castToExpressedType(Type quantizedType); 166 167 /// Casts from a type based on the expressedType to the equivalent type 168 /// based on storageType by way of this QuantizedType. Equivalent to: 169 /// QuantizedType::castToStorageType(castFromExpressedType(candidateType)) 170 /// (but with validity checks). 171 /// Example (for this = !quant.uniform<i8:f32, 1.0>): 172 /// tensor<4xf32> -> tensor<4xi8> 173 Type castExpressedToStorageType(Type candidateType); 174 175 private: 176 /// Hide the following methods inherited from `Type`. It is almost certainly 177 /// a bug to call them from a `QuantizedType` object. Users should call 178 /// `getStorageType` or `getExpressedType` to get the underlying types 179 /// they want to inspect. 180 using Type::isBF16; 181 using Type::isF16; 182 using Type::isF32; 183 using Type::isF64; 184 using Type::isIndex; 185 using Type::isInteger; 186 }; 187 188 /// A quantized type that maps storage to/from expressed types in an 189 /// unspecified way. 190 /// 191 /// Typical syntax: 192 /// quant.any<i8:f32> 193 /// quant.any<i8> 194 /// quant.any<i8<-16,15>> 195 /// 196 /// Note that for the any type, the expressed type is optional. 197 class AnyQuantizedType 198 : public Type::TypeBase<AnyQuantizedType, QuantizedType, 199 detail::AnyQuantizedTypeStorage> { 200 public: 201 using Base::Base; 202 203 /// Gets an instance of the type with all parameters specified but not 204 /// checked. 205 static AnyQuantizedType get(unsigned flags, Type storageType, 206 Type expressedType, int64_t storageTypeMin, 207 int64_t storageTypeMax); 208 209 /// Gets an instance of the type with all specified parameters checked. 210 /// Returns a nullptr convertible type on failure. 211 static AnyQuantizedType getChecked(unsigned flags, Type storageType, 212 Type expressedType, int64_t storageTypeMin, 213 int64_t storageTypeMax, Location location); 214 215 /// Verifies construction invariants and issues errors/warnings. 216 static LogicalResult 217 verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, 218 Type expressedType, int64_t storageTypeMin, 219 int64_t storageTypeMax); 220 }; 221 222 /// Represents a family of uniform, quantized types. 223 /// 224 /// Each instance of this type expresses a mapping between real values (most 225 /// often expressed in floating point f32) and quantized values (either fixed 226 /// point or affine). 227 /// 228 /// The relationship is: 229 /// real_value = scale * (quantized_value - zero_point) 230 /// 231 /// It is used as part of high level graph transformations that have the goal 232 /// of re-expressing parts of a computation in terms of this common form for 233 /// more efficient execution at runtime. In addition, it is designed to be 234 /// expressive enough to facilitate lowering to precise types and operations 235 /// in target hardware. 236 /// 237 /// As a high-level type, focused on intermediate passes, this type holds 238 /// opinions consistent with high-level usage. If lowering math kernels below 239 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific 240 /// instruction sets), it is expected that the information expressed here 241 /// will be used to drive low level codegen and target specific type selection, 242 /// but this type will likely be erased in the process. 243 /// 244 /// Syntax synopsis: 245 /// Per-layer, all parameters expressed: 246 /// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}> 247 /// Per-layer, optional parameters omitted: 248 /// !quant<uniform[StorageType]{Scale}> 249 /// 250 /// StorageType: 'i'|'u' NumBits 251 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' 252 /// Scale: A legal double value 253 /// ZeroPoint: An integer value 254 class UniformQuantizedType 255 : public Type::TypeBase<UniformQuantizedType, QuantizedType, 256 detail::UniformQuantizedTypeStorage> { 257 public: 258 using Base::Base; 259 260 /// Gets an instance of the type with all parameters specified but not 261 /// checked. 262 static UniformQuantizedType get(unsigned flags, Type storageType, 263 Type expressedType, double scale, 264 int64_t zeroPoint, int64_t storageTypeMin, 265 int64_t storageTypeMax); 266 267 /// Gets an instance of the type with all specified parameters checked. 268 /// Returns a nullptr convertible type on failure. 269 static UniformQuantizedType 270 getChecked(unsigned flags, Type storageType, Type expressedType, double scale, 271 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, 272 Location location); 273 274 /// Verifies construction invariants and issues errors/warnings. 275 static LogicalResult 276 verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, 277 Type expressedType, double scale, 278 int64_t zeroPoint, int64_t storageTypeMin, 279 int64_t storageTypeMax); 280 281 /// Gets the scale term. The scale designates the difference between the real 282 /// values corresponding to consecutive quantized values differing by 1. 283 double getScale() const; 284 285 /// Gets the storage value corresponding to the real value 0 in the affine 286 /// equation. 287 int64_t getZeroPoint() const; 288 289 // Fixed point values are real numbers divided by a scale. 290 // Currently, only signed storage types are treated as fixed point. 291 // A fixed point value can be obtained from an affine value by subtracting 292 // the zeroPoint. 293 // In the future, this may be explicit versus implied by type and zeroPoint. isFixedPoint()294 bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } 295 }; 296 297 /// Represents per-axis (also known as per-channel quantization). 298 /// 299 /// Syntax synopsis: 300 /// Per-axis, all parameters expressed: 301 /// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}> 302 /// Per-axis, optional parameters omitted: 303 /// !quant<uniform[StorageType]{Scale}> 304 /// 305 /// StorageType: 'i'|'u' NumBits 306 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' 307 /// QuantizedDim: An integer value 308 /// QuantParams: (Scale ':' ZeroPoint)+ 309 /// Scale: A legal double value 310 /// ZeroPoint: An integer value 311 class UniformQuantizedPerAxisType 312 : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType, 313 detail::UniformQuantizedPerAxisTypeStorage> { 314 public: 315 using Base::Base; 316 317 /// Gets an instance of the type with all parameters specified but not 318 /// checked. 319 static UniformQuantizedPerAxisType 320 get(unsigned flags, Type storageType, Type expressedType, 321 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 322 int32_t quantizedDimension, int64_t storageTypeMin, 323 int64_t storageTypeMax); 324 325 /// Gets an instance of the type with all specified parameters checked. 326 /// Returns a nullptr convertible type on failure. 327 static UniformQuantizedPerAxisType 328 getChecked(unsigned flags, Type storageType, Type expressedType, 329 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 330 int32_t quantizedDimension, int64_t storageTypeMin, 331 int64_t storageTypeMax, Location location); 332 333 /// Verifies construction invariants and issues errors/warnings. 334 static LogicalResult 335 verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, 336 Type expressedType, ArrayRef<double> scales, 337 ArrayRef<int64_t> zeroPoints, 338 int32_t quantizedDimension, 339 int64_t storageTypeMin, int64_t storageTypeMax); 340 341 /// Gets the quantization scales. The scales designate the difference between 342 /// the real values corresponding to consecutive quantized values differing 343 /// by 1. The ith scale corresponds to the ith slice in the 344 /// quantized_dimension. 345 ArrayRef<double> getScales() const; 346 347 /// Gets the storage values corresponding to the real value 0 in the affine 348 /// equation. The ith zero point corresponds to the ith slice in the 349 /// quantized_dimension. 350 ArrayRef<int64_t> getZeroPoints() const; 351 352 /// Specifies the dimension of the Tensor's shape that the scales and 353 /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] 354 /// with quantization params: 355 /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1 356 /// will be quantized across the second dimension of t. 357 /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 358 /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 359 /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 360 int32_t getQuantizedDimension() const; 361 362 /// Fixed point values are real numbers divided by a scale. 363 /// Currently, only signed storage types are treated as fixed point. 364 /// A fixed point value can be obtained from an affine value by subtracting 365 /// the zeroPoint. 366 /// In the future, this may be explicit versus implied by type and zeroPoint. isFixedPoint()367 bool isFixedPoint() const { 368 if (!isSigned()) 369 return false; 370 return llvm::all_of(getZeroPoints(), 371 [](int64_t zeroPoint) { return zeroPoint != 0; }); 372 } 373 }; 374 375 /// A quantized type that infers its range from given min/max values. 376 /// 377 /// Typical syntax: 378 /// quant.calibrated<f32<-0.922,0.981>> 379 class CalibratedQuantizedType 380 : public Type::TypeBase<CalibratedQuantizedType, QuantizedType, 381 detail::CalibratedQuantizedTypeStorage> { 382 public: 383 using Base::Base; 384 385 /// Gets an instance of the type with all parameters specified but not 386 /// checked. 387 static CalibratedQuantizedType get(Type expressedType, double min, 388 double max); 389 390 /// Gets an instance of the type with all specified parameters checked. 391 /// Returns a nullptr convertible type on failure. 392 static CalibratedQuantizedType getChecked(Type expressedType, double min, 393 double max, Location location); 394 395 /// Verifies construction invariants and issues errors/warnings. 396 static LogicalResult verifyConstructionInvariants(Location loc, 397 Type expressedType, 398 double min, double max); 399 double getMin() const; 400 double getMax() const; 401 }; 402 403 } // namespace quant 404 } // namespace mlir 405 406 #endif // MLIR_DIALECT_QUANT_QUANT_TYPES_H_ 407