1 //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ 10 #define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ 11 12 #include "mlir/Dialect/Quant/QuantTypes.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/Types.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/APInt.h" 17 #include "llvm/ADT/APSInt.h" 18 19 namespace mlir { 20 namespace quant { 21 22 /// Performs type conversion from an arbitrary input type to a type 23 /// that is expressed by a QuantizedType. 24 /// 25 /// This handles cases where the inputType is a supported primitive type 26 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported 27 /// elemental type. 28 /// 29 /// Since conversion often involves introspecting some attributes of the 30 /// input type in order to determine how to represent it, this is a two step 31 /// process. 32 struct ExpressedToQuantizedConverter { 33 /// Creates a converter for the given input type. 34 static const ExpressedToQuantizedConverter forInputType(Type inputType); 35 36 /// Converts the inputType to be based on the given elemental type, 37 /// returning the new type (or nullptr and emit an error on failure). 38 Type convert(QuantizedType elementalType) const; 39 40 /// Whether the conversion is legal. 41 explicit operator bool() const { return (bool)expressedType; } 42 43 /// The input type that is being converted from. 44 /// This may be an elemental or composite type. 45 const Type inputType; 46 47 /// Supported, elemental expressed type (i.e. f32). 48 /// Will be nullptr if conversion is not supported. 49 const Type expressedType; 50 }; 51 52 /// Reference implementation of converting between real numbers and values 53 /// represented by a UniformQuantizedType. 54 /// Note that this is not expected to be speedy and may be superseded eventually 55 /// by a more optimal implementation. 56 /// Also, the interface assumes that quantization is done per-layer and will 57 /// need to be wider for various per-channel schemes. As such, this is a 58 /// placeholder. 59 class UniformQuantizedValueConverter { 60 public: UniformQuantizedValueConverter(UniformQuantizedType uniformType)61 explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType) 62 : UniformQuantizedValueConverter( 63 uniformType.getScale(), 64 static_cast<double>(uniformType.getZeroPoint()), 65 static_cast<double>(uniformType.getStorageTypeMin()), 66 static_cast<double>(uniformType.getStorageTypeMax()), 67 uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { 68 assert(uniformType.getExpressedType().isa<FloatType>()); 69 assert(uniformType.getStorageType().isSignlessInteger()); 70 } 71 UniformQuantizedValueConverter(double scale,double zeroPoint,double clampMin,double clampMax,uint32_t storageBitWidth,bool isSigned)72 UniformQuantizedValueConverter(double scale, double zeroPoint, 73 double clampMin, double clampMax, 74 uint32_t storageBitWidth, bool isSigned) 75 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin), 76 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint), 77 clampMinDouble(clampMin), clampMaxDouble(clampMax), 78 storageBitWidth(storageBitWidth), isSigned(isSigned), 79 roundMode(APFloat::rmNearestTiesToAway) {} 80 UniformQuantizedValueConverter(double scale,double zeroPoint,APFloat clampMin,APFloat clampMax,uint32_t storageBitWidth,bool isSigned)81 UniformQuantizedValueConverter(double scale, double zeroPoint, 82 APFloat clampMin, APFloat clampMax, 83 uint32_t storageBitWidth, bool isSigned) 84 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin), 85 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint), 86 clampMinDouble(clampMin.convertToDouble()), 87 clampMaxDouble(clampMax.convertToDouble()), 88 storageBitWidth(storageBitWidth), isSigned(isSigned), 89 roundMode(APFloat::rmNearestTiesToAway) {} 90 quantizeFloatToInt(APFloat expressedValue)91 virtual APInt quantizeFloatToInt(APFloat expressedValue) const { 92 // This function is a performance critical code path in quantization 93 // since it runs for each single float parameter value. 94 95 // Specialize f32->u8/i8 case to optimize performance. 96 if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() && 97 storageBitWidth == 8 && 98 roundMode == llvm::APFloatBase::rmNearestTiesToAway) { 99 return quantizeF32ToInt8(expressedValue); 100 } 101 102 bool lossy; 103 expressedValue.convert(scale.getSemantics(), roundMode, &lossy); 104 // fixedpoint = clamp(clampMin, clampMax, ( 105 // roundHalfToEven(expressed / scale) + zeroPoint)) 106 APFloat scaled = (expressedValue / scale); 107 scaled.roundToIntegral(roundMode); 108 scaled.add(zeroPoint, roundMode); 109 APFloat fixedpoint = llvm::minimum(scaled, clampMax); 110 fixedpoint = llvm::maximum(fixedpoint, clampMin); 111 112 llvm::APSInt result(storageBitWidth, !isSigned); 113 fixedpoint.convertToInteger(result, roundMode, &lossy); 114 115 return std::move(result); 116 } 117 quantizeFloatToInt64(APFloat expressedValue)118 int64_t quantizeFloatToInt64(APFloat expressedValue) const { 119 APInt qValue = quantizeFloatToInt(expressedValue); 120 return isSigned ? qValue.getSExtValue() : qValue.getZExtValue(); 121 } 122 ~UniformQuantizedValueConverter()123 virtual ~UniformQuantizedValueConverter() {} 124 125 private: 126 // An optimized implementation to quantize f32 to i8/u8 with C++ native 127 // arithmetic. quantizeF32ToInt8(APFloat expressedValue)128 virtual APInt quantizeF32ToInt8(APFloat expressedValue) const { 129 assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle()); 130 assert(storageBitWidth == 8); 131 assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway); 132 133 const float realValue = expressedValue.convertToFloat(); 134 135 const double scaled = realValue / scaleDouble + zeroPointDouble; 136 // Round to nearest integer with halfway cases rounded away from zero. 137 const double scaledRounded = std::round(scaled); 138 const double clamped = 139 std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble); 140 141 uint64_t signlessResult; 142 if (isSigned) { 143 int64_t clampedInt = static_cast<int8_t>(clamped); 144 memcpy(&signlessResult, &clampedInt, sizeof(clampedInt)); 145 } else { 146 signlessResult = static_cast<uint8_t>(clamped); 147 } 148 return APInt(storageBitWidth, signlessResult); 149 } 150 151 // Keep both APFloat and double versions of the quantization parameters 152 // around since they will be used in generic and specialized arithmetic, 153 // respectively. 154 const APFloat scale; 155 const APFloat zeroPoint; 156 const APFloat clampMin; 157 const APFloat clampMax; 158 159 const double scaleDouble; 160 const double zeroPointDouble; 161 const double clampMinDouble; 162 const double clampMaxDouble; 163 164 const uint32_t storageBitWidth; 165 const bool isSigned; 166 const llvm::APFloat::roundingMode roundMode; 167 }; 168 169 /// An utility class to quantize an attribute by the per-axis quantization 170 /// parameters. The size of the quantization dim in the converted elements 171 /// attribute should matche the size of of scales/zeroPoints vectors in the 172 /// quantization parameters. 173 class UniformQuantizedPerAxisValueConverter { 174 public: UniformQuantizedPerAxisValueConverter(UniformQuantizedPerAxisType uniformType)175 explicit UniformQuantizedPerAxisValueConverter( 176 UniformQuantizedPerAxisType uniformType) 177 : scales(uniformType.getScales()), 178 zeroPoints(uniformType.getZeroPoints()), 179 clampMin(static_cast<double>(uniformType.getStorageTypeMin())), 180 clampMax(static_cast<double>(uniformType.getStorageTypeMax())), 181 storageBitWidth(uniformType.getStorageTypeIntegralWidth()), 182 isSigned(uniformType.isSigned()), 183 quantizationDim(uniformType.getQuantizedDimension()) { 184 assert(uniformType.getExpressedType().isa<FloatType>()); 185 assert(uniformType.getStorageType().isSignlessInteger()); 186 assert(scales.size() == zeroPoints.size()); 187 } 188 189 /// Quantize an Attribute by the quantization parameters. Return nullptr if 190 /// the conversion fails or the input array isn't an ElementsAttr. 191 ElementsAttr convert(Attribute realValue); 192 193 private: 194 /// Quantize an DenseFPElementsAttr by the quantization parameters. 195 DenseElementsAttr convert(DenseFPElementsAttr attr); 196 197 /// Get a uniform converter for the index-th chunk along the quantizationDim. 198 /// All the elements in this chunk is quantized by the returned converter. getPerChunkConverter(int index)199 UniformQuantizedValueConverter getPerChunkConverter(int index) const { 200 UniformQuantizedValueConverter converter(scales[index], zeroPoints[index], 201 clampMin, clampMax, 202 storageBitWidth, isSigned); 203 return converter; 204 } 205 206 const ArrayRef<double> scales; 207 const ArrayRef<int64_t> zeroPoints; 208 const APFloat clampMin; 209 const APFloat clampMax; 210 const uint32_t storageBitWidth; 211 const bool isSigned; 212 int32_t quantizationDim; 213 }; 214 215 } // namespace quant 216 } // namespace mlir 217 218 #endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ 219