• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
10 #define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
11 
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
18 
19 namespace mlir {
20 namespace quant {
21 
22 /// Performs type conversion from an arbitrary input type to a type
23 /// that is expressed by a QuantizedType.
24 ///
25 /// This handles cases where the inputType is a supported primitive type
26 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
27 /// elemental type.
28 ///
29 /// Since conversion often involves introspecting some attributes of the
30 /// input type in order to determine how to represent it, this is a two step
31 /// process.
32 struct ExpressedToQuantizedConverter {
33   /// Creates a converter for the given input type.
34   static const ExpressedToQuantizedConverter forInputType(Type inputType);
35 
36   /// Converts the inputType to be based on the given elemental type,
37   /// returning the new type (or nullptr and emit an error on failure).
38   Type convert(QuantizedType elementalType) const;
39 
40   /// Whether the conversion is legal.
41   explicit operator bool() const { return (bool)expressedType; }
42 
43   /// The input type that is being converted from.
44   /// This may be an elemental or composite type.
45   const Type inputType;
46 
47   /// Supported, elemental expressed type (i.e. f32).
48   /// Will be nullptr if conversion is not supported.
49   const Type expressedType;
50 };
51 
52 /// Reference implementation of converting between real numbers and values
53 /// represented by a UniformQuantizedType.
54 /// Note that this is not expected to be speedy and may be superseded eventually
55 /// by a more optimal implementation.
56 /// Also, the interface assumes that quantization is done per-layer and will
57 /// need to be wider for various per-channel schemes. As such, this is a
58 /// placeholder.
59 class UniformQuantizedValueConverter {
60 public:
UniformQuantizedValueConverter(UniformQuantizedType uniformType)61   explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType)
62       : UniformQuantizedValueConverter(
63             uniformType.getScale(),
64             static_cast<double>(uniformType.getZeroPoint()),
65             static_cast<double>(uniformType.getStorageTypeMin()),
66             static_cast<double>(uniformType.getStorageTypeMax()),
67             uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
68     assert(uniformType.getExpressedType().isa<FloatType>());
69     assert(uniformType.getStorageType().isSignlessInteger());
70   }
71 
UniformQuantizedValueConverter(double scale,double zeroPoint,double clampMin,double clampMax,uint32_t storageBitWidth,bool isSigned)72   UniformQuantizedValueConverter(double scale, double zeroPoint,
73                                  double clampMin, double clampMax,
74                                  uint32_t storageBitWidth, bool isSigned)
75       : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
76         clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
77         clampMinDouble(clampMin), clampMaxDouble(clampMax),
78         storageBitWidth(storageBitWidth), isSigned(isSigned),
79         roundMode(APFloat::rmNearestTiesToAway) {}
80 
UniformQuantizedValueConverter(double scale,double zeroPoint,APFloat clampMin,APFloat clampMax,uint32_t storageBitWidth,bool isSigned)81   UniformQuantizedValueConverter(double scale, double zeroPoint,
82                                  APFloat clampMin, APFloat clampMax,
83                                  uint32_t storageBitWidth, bool isSigned)
84       : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
85         clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
86         clampMinDouble(clampMin.convertToDouble()),
87         clampMaxDouble(clampMax.convertToDouble()),
88         storageBitWidth(storageBitWidth), isSigned(isSigned),
89         roundMode(APFloat::rmNearestTiesToAway) {}
90 
quantizeFloatToInt(APFloat expressedValue)91   virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
92     // This function is a performance critical code path in quantization
93     // since it runs for each single float parameter value.
94 
95     // Specialize f32->u8/i8 case to optimize performance.
96     if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
97         storageBitWidth == 8 &&
98         roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
99       return quantizeF32ToInt8(expressedValue);
100     }
101 
102     bool lossy;
103     expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
104     // fixedpoint = clamp(clampMin, clampMax, (
105     //   roundHalfToEven(expressed / scale) + zeroPoint))
106     APFloat scaled = (expressedValue / scale);
107     scaled.roundToIntegral(roundMode);
108     scaled.add(zeroPoint, roundMode);
109     APFloat fixedpoint = llvm::minimum(scaled, clampMax);
110     fixedpoint = llvm::maximum(fixedpoint, clampMin);
111 
112     llvm::APSInt result(storageBitWidth, !isSigned);
113     fixedpoint.convertToInteger(result, roundMode, &lossy);
114 
115     return std::move(result);
116   }
117 
quantizeFloatToInt64(APFloat expressedValue)118   int64_t quantizeFloatToInt64(APFloat expressedValue) const {
119     APInt qValue = quantizeFloatToInt(expressedValue);
120     return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
121   }
122 
~UniformQuantizedValueConverter()123   virtual ~UniformQuantizedValueConverter() {}
124 
125 private:
126   // An optimized implementation to quantize f32 to i8/u8 with C++ native
127   // arithmetic.
quantizeF32ToInt8(APFloat expressedValue)128   virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
129     assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
130     assert(storageBitWidth == 8);
131     assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
132 
133     const float realValue = expressedValue.convertToFloat();
134 
135     const double scaled = realValue / scaleDouble + zeroPointDouble;
136     // Round to nearest integer with halfway cases rounded away from zero.
137     const double scaledRounded = std::round(scaled);
138     const double clamped =
139         std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble);
140 
141     uint64_t signlessResult;
142     if (isSigned) {
143       int64_t clampedInt = static_cast<int8_t>(clamped);
144       memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
145     } else {
146       signlessResult = static_cast<uint8_t>(clamped);
147     }
148     return APInt(storageBitWidth, signlessResult);
149   }
150 
151   // Keep both APFloat and double versions of the quantization parameters
152   // around since they will be used in generic and specialized arithmetic,
153   // respectively.
154   const APFloat scale;
155   const APFloat zeroPoint;
156   const APFloat clampMin;
157   const APFloat clampMax;
158 
159   const double scaleDouble;
160   const double zeroPointDouble;
161   const double clampMinDouble;
162   const double clampMaxDouble;
163 
164   const uint32_t storageBitWidth;
165   const bool isSigned;
166   const llvm::APFloat::roundingMode roundMode;
167 };
168 
169 /// An utility class to quantize an attribute by the per-axis quantization
170 /// parameters. The size of the quantization dim in the converted elements
171 /// attribute should matche the size of of scales/zeroPoints vectors in the
172 /// quantization parameters.
173 class UniformQuantizedPerAxisValueConverter {
174 public:
UniformQuantizedPerAxisValueConverter(UniformQuantizedPerAxisType uniformType)175   explicit UniformQuantizedPerAxisValueConverter(
176       UniformQuantizedPerAxisType uniformType)
177       : scales(uniformType.getScales()),
178         zeroPoints(uniformType.getZeroPoints()),
179         clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
180         clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
181         storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
182         isSigned(uniformType.isSigned()),
183         quantizationDim(uniformType.getQuantizedDimension()) {
184     assert(uniformType.getExpressedType().isa<FloatType>());
185     assert(uniformType.getStorageType().isSignlessInteger());
186     assert(scales.size() == zeroPoints.size());
187   }
188 
189   /// Quantize an Attribute by the quantization parameters. Return nullptr if
190   /// the conversion fails or the input array isn't an ElementsAttr.
191   ElementsAttr convert(Attribute realValue);
192 
193 private:
194   /// Quantize an DenseFPElementsAttr by the quantization parameters.
195   DenseElementsAttr convert(DenseFPElementsAttr attr);
196 
197   /// Get a uniform converter for the index-th chunk along the quantizationDim.
198   /// All the elements in this chunk is quantized by the returned converter.
getPerChunkConverter(int index)199   UniformQuantizedValueConverter getPerChunkConverter(int index) const {
200     UniformQuantizedValueConverter converter(scales[index], zeroPoints[index],
201                                              clampMin, clampMax,
202                                              storageBitWidth, isSigned);
203     return converter;
204   }
205 
206   const ArrayRef<double> scales;
207   const ArrayRef<int64_t> zeroPoints;
208   const APFloat clampMin;
209   const APFloat clampMax;
210   const uint32_t storageBitWidth;
211   const bool isSigned;
212   int32_t quantizationDim;
213 };
214 
215 } // namespace quant
216 } // namespace mlir
217 
218 #endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
219