1 //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef TYPE_DETAIL_H_ 10 #define TYPE_DETAIL_H_ 11 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/TypeSupport.h" 14 #include "mlir/IR/Types.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/Hashing.h" 17 #include "llvm/ADT/bit.h" 18 19 namespace mlir { 20 namespace quant { 21 namespace detail { 22 23 struct QuantizedTypeStorage : public mlir::TypeStorage { QuantizedTypeStorageQuantizedTypeStorage24 QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType, 25 int64_t storageTypeMin, int64_t storageTypeMax) 26 : flags(flags), storageType(storageType), expressedType(expressedType), 27 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} 28 29 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. 30 unsigned flags; 31 32 // Integral type for the storage point representation. 33 Type storageType; 34 35 // Floating point type that the quantized type approximates. 36 Type expressedType; 37 38 // The minimum value storageType can take. 39 int64_t storageTypeMin; 40 41 // The maximum value storageType can take. 42 int64_t storageTypeMax; 43 }; 44 45 struct AnyQuantizedTypeStorage : public QuantizedTypeStorage { 46 struct KeyTy { KeyTyAnyQuantizedTypeStorage::KeyTy47 KeyTy(unsigned flags, Type storageType, Type expressedType, 48 int64_t storageTypeMin, int64_t storageTypeMax) 49 : flags(flags), storageType(storageType), expressedType(expressedType), 50 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} 51 unsigned flags; 52 Type storageType; 53 Type expressedType; 54 int64_t storageTypeMin; 55 int64_t storageTypeMax; 56 57 // Check for equality of two structures that share KeyTy data members 58 // (by name). 59 template <typename T, typename U> genericIsEqualAnyQuantizedTypeStorage::KeyTy60 static bool genericIsEqual(const T &lhs, const U &rhs) { 61 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && 62 lhs.expressedType == rhs.expressedType && 63 lhs.storageTypeMin == rhs.storageTypeMin && 64 lhs.storageTypeMax == rhs.storageTypeMax; 65 } 66 67 bool operator==(const KeyTy &other) const { 68 return genericIsEqual(*this, other); 69 } 70 getHashValueAnyQuantizedTypeStorage::KeyTy71 unsigned getHashValue() const { 72 return llvm::hash_combine(flags, storageType, expressedType, 73 storageTypeMin, storageTypeMax); 74 } 75 }; 76 AnyQuantizedTypeStorageAnyQuantizedTypeStorage77 AnyQuantizedTypeStorage(const KeyTy &key) 78 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, 79 key.storageTypeMin, key.storageTypeMax) {} 80 81 bool operator==(const KeyTy &key) const { 82 return KeyTy::genericIsEqual(*this, key); 83 } 84 85 /// Construction. constructAnyQuantizedTypeStorage86 static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, 87 const KeyTy &key) { 88 return new (allocator.allocate<AnyQuantizedTypeStorage>()) 89 AnyQuantizedTypeStorage(key); 90 } 91 hashKeyAnyQuantizedTypeStorage92 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } 93 }; 94 95 struct UniformQuantizedTypeStorage : public QuantizedTypeStorage { 96 struct KeyTy { KeyTyUniformQuantizedTypeStorage::KeyTy97 KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, 98 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) 99 : flags(flags), storageType(storageType), expressedType(expressedType), 100 scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin), 101 storageTypeMax(storageTypeMax) {} 102 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. 103 unsigned flags; 104 105 // Integral type for the storage point representation. 106 Type storageType; 107 108 // Floating point type that the quantized type approximates. 109 Type expressedType; 110 111 double scale; 112 int64_t zeroPoint; 113 int64_t storageTypeMin; 114 int64_t storageTypeMax; 115 116 // Check for equality of two structures that share KeyTy data members 117 // (by name). 118 template <typename T, typename U> genericIsEqualUniformQuantizedTypeStorage::KeyTy119 static bool genericIsEqual(const T &lhs, const U &rhs) { 120 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && 121 lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale && 122 lhs.zeroPoint == rhs.zeroPoint && 123 lhs.storageTypeMin == rhs.storageTypeMin && 124 lhs.storageTypeMax == rhs.storageTypeMax; 125 } 126 127 bool operator==(const KeyTy &other) const { 128 return genericIsEqual(*this, other); 129 } 130 getHashValueUniformQuantizedTypeStorage::KeyTy131 unsigned getHashValue() const { 132 int64_t scaleBits = llvm::bit_cast<int64_t>(scale); 133 return llvm::hash_combine(flags, storageType, expressedType, scaleBits, 134 zeroPoint, storageTypeMin, storageTypeMax); 135 } 136 }; 137 UniformQuantizedTypeStorageUniformQuantizedTypeStorage138 UniformQuantizedTypeStorage(const KeyTy &key) 139 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, 140 key.storageTypeMin, key.storageTypeMax), 141 scale(key.scale), zeroPoint(key.zeroPoint) {} 142 143 bool operator==(const KeyTy &key) const { 144 return KeyTy::genericIsEqual(*this, key); 145 } 146 147 /// Construction. constructUniformQuantizedTypeStorage148 static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, 149 const KeyTy &key) { 150 return new (allocator.allocate<UniformQuantizedTypeStorage>()) 151 UniformQuantizedTypeStorage(key); 152 } 153 hashKeyUniformQuantizedTypeStorage154 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } 155 156 double scale; 157 int64_t zeroPoint; 158 }; 159 160 struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { 161 struct KeyTy { KeyTyUniformQuantizedPerAxisTypeStorage::KeyTy162 KeyTy(unsigned flags, Type storageType, Type expressedType, 163 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 164 int32_t quantizedDimension, int64_t storageTypeMin, 165 int64_t storageTypeMax) 166 : flags(flags), storageType(storageType), expressedType(expressedType), 167 scales(scales), zeroPoints(zeroPoints), 168 quantizedDimension(quantizedDimension), 169 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} 170 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. 171 unsigned flags; 172 173 // Integral type for the storage point representation. 174 Type storageType; 175 176 // Floating point type that the quantized type approximates. 177 Type expressedType; 178 179 ArrayRef<double> scales; 180 ArrayRef<int64_t> zeroPoints; 181 int32_t quantizedDimension; 182 int64_t storageTypeMin; 183 int64_t storageTypeMax; 184 getScalesUniformQuantizedPerAxisTypeStorage::KeyTy185 ArrayRef<double> getScales() const { return scales; } 186 getZeroPointsUniformQuantizedPerAxisTypeStorage::KeyTy187 ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; } 188 189 // Check for equality of two structures that share KeyTy data members 190 // (by name). 191 template <typename T, typename U> genericIsEqualUniformQuantizedPerAxisTypeStorage::KeyTy192 static bool genericIsEqual(const T &lhs, const U &rhs) { 193 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && 194 lhs.expressedType == rhs.expressedType && 195 lhs.getScales() == rhs.getScales() && 196 lhs.getZeroPoints() == rhs.getZeroPoints() && 197 lhs.quantizedDimension == rhs.quantizedDimension && 198 lhs.storageTypeMin == rhs.storageTypeMin && 199 lhs.storageTypeMax == rhs.storageTypeMax; 200 } 201 202 bool operator==(const KeyTy &other) const { 203 return genericIsEqual(*this, other); 204 } 205 getHashValueUniformQuantizedPerAxisTypeStorage::KeyTy206 unsigned getHashValue() const { 207 int64_t *scalesCast = llvm::bit_cast<int64_t *>(scales.data()); 208 ArrayRef<int64_t> scalesBits(scalesCast, scales.size()); 209 return llvm::hash_combine( 210 flags, storageType, expressedType, 211 llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()), 212 llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()), 213 storageTypeMin, storageTypeMax); 214 } 215 }; 216 217 // We pass scales and zeroPoints in directly rather than relying on KeyTy 218 // because we have to create new reallocated versions in `construct` below. UniformQuantizedPerAxisTypeStorageUniformQuantizedPerAxisTypeStorage219 UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales, 220 ArrayRef<int64_t> zeroPoints) 221 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, 222 key.storageTypeMin, key.storageTypeMax), 223 scaleElements(scales.data()), zeroPointElements(zeroPoints.data()), 224 quantParamsSize(scales.size()), 225 quantizedDimension(key.quantizedDimension) {} 226 227 bool operator==(const KeyTy &key) const { 228 return KeyTy::genericIsEqual(*this, key); 229 } 230 231 /// Construction. 232 static UniformQuantizedPerAxisTypeStorage * constructUniformQuantizedPerAxisTypeStorage233 construct(TypeStorageAllocator &allocator, const KeyTy &key) { 234 ArrayRef<double> scales = allocator.copyInto(key.scales); 235 ArrayRef<int64_t> zeroPoints = allocator.copyInto(key.zeroPoints); 236 return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>()) 237 UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints); 238 } 239 hashKeyUniformQuantizedPerAxisTypeStorage240 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } 241 getScalesUniformQuantizedPerAxisTypeStorage242 ArrayRef<double> getScales() const { 243 return ArrayRef<double>(scaleElements, quantParamsSize); 244 } 245 getZeroPointsUniformQuantizedPerAxisTypeStorage246 ArrayRef<int64_t> getZeroPoints() const { 247 return ArrayRef<int64_t>(zeroPointElements, quantParamsSize); 248 } 249 250 const double *scaleElements; 251 const int64_t *zeroPointElements; 252 unsigned quantParamsSize; 253 int32_t quantizedDimension; 254 }; 255 256 struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { 257 struct KeyTy { KeyTyCalibratedQuantizedTypeStorage::KeyTy258 KeyTy(Type expressedType, double min, double max) 259 : expressedType(expressedType), min(min), max(max) {} 260 // Floating point type that the quantized type approximates. 261 Type expressedType; 262 263 double min; 264 double max; 265 266 // Check for equality of two structures that share KeyTy data members 267 // (by name). 268 template <typename T, typename U> genericIsEqualCalibratedQuantizedTypeStorage::KeyTy269 static bool genericIsEqual(const T &lhs, const U &rhs) { 270 return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min && 271 lhs.max == rhs.max; 272 } 273 274 bool operator==(const KeyTy &other) const { 275 return genericIsEqual(*this, other); 276 } 277 getHashValueCalibratedQuantizedTypeStorage::KeyTy278 unsigned getHashValue() const { 279 int64_t minBits = llvm::bit_cast<double>(min); 280 int64_t maxBits = llvm::bit_cast<double>(max); 281 return llvm::hash_combine(expressedType, minBits, maxBits); 282 } 283 }; 284 CalibratedQuantizedTypeStorageCalibratedQuantizedTypeStorage285 CalibratedQuantizedTypeStorage(const KeyTy &key) 286 : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0), 287 min(key.min), max(key.max) {} 288 289 bool operator==(const KeyTy &key) const { 290 return KeyTy::genericIsEqual(*this, key); 291 } 292 293 /// Construction. 294 static CalibratedQuantizedTypeStorage * constructCalibratedQuantizedTypeStorage295 construct(TypeStorageAllocator &allocator, const KeyTy &key) { 296 return new (allocator.allocate<CalibratedQuantizedTypeStorage>()) 297 CalibratedQuantizedTypeStorage(key); 298 } 299 hashKeyCalibratedQuantizedTypeStorage300 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } 301 302 double min; 303 double max; 304 }; 305 306 } // namespace detail 307 } // namespace quant 308 } // namespace mlir 309 310 #endif // TYPE_DETAIL_H_ 311