1 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/QuantTypes.h"

#include <cmath>
#include <limits>
11
12 using namespace mlir;
13 using namespace mlir::quant;
14
getDefaultStorageParams(unsigned numBits,bool narrowRange,bool isSigned,MLIRContext * ctx,Type & storageType,int64_t & qmin,int64_t & qmax)15 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16 bool isSigned, MLIRContext *ctx,
17 Type &storageType, int64_t &qmin,
18 int64_t &qmax) {
19 // Hard-coded type mapping from TFLite.
20 if (numBits <= 8) {
21 storageType = IntegerType::get(8, ctx);
22 if (isSigned) {
23 qmin = -128;
24 qmax = 127;
25 } else {
26 qmin = 0;
27 qmax = 255;
28 }
29 } else if (numBits <= 16) {
30 storageType = IntegerType::get(16, ctx);
31 if (isSigned) {
32 qmin = -32768;
33 qmax = 32767;
34 } else {
35 qmin = 0;
36 qmax = 65535;
37 }
38 } else if (numBits <= 32) {
39 storageType = IntegerType::get(32, ctx);
40 if (isSigned) {
41 qmin = std::numeric_limits<int32_t>::min();
42 qmax = std::numeric_limits<int32_t>::max();
43 } else {
44 qmin = std::numeric_limits<uint32_t>::min();
45 qmax = std::numeric_limits<uint32_t>::max();
46 }
47 } else {
48 return true;
49 }
50
51 // Handle narrowRange.
52 if (narrowRange) {
53 qmin += 1;
54 }
55 return false;
56 }
57
// Computes the quantization scale and an integer ("nudged") zero point that
// map the real range [rmin, rmax] onto the storage range [qmin, qmax].
//
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted to
// include 0.0, but the range width (rmax - rmin) isn't changed. The zero point
// is derived from the shifted range, and the scale isn't changed. As a
// consequence, some values that are supposed to lie in the original
// [rmin, rmax] range will fall outside the shifted range and be clamped during
// quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// Outputs:
//   scale           - (rmax - rmin) / (qmax - qmin).
//   nudgedZeroPoint - the real-valued zero point rounded to the nearest
//                     integer (halfway cases away from zero), or qmin/qmax
//                     when it falls below/above the storage range.
//
// Preconditions (not checked here): qmin != qmax, and rmax != rmin. A zero
// width on either side makes the scale infinite or zero, and the zero-point
// computation below divides by the scale.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = static_cast<double>(qmin);
  const double qmaxDouble = static_cast<double>(qmax);
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer inside [qmin, qmax].
  if (zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    // Use std::round: <cmath> only guarantees the std:: overloads, not the
    // unqualified C ::round. Make the double -> int64_t conversion explicit;
    // the value is already within [qmin, qmax] here.
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
106
107 UniformQuantizedType
fakeQuantAttrsToType(Location loc,unsigned numBits,double rmin,double rmax,bool narrowRange,Type expressedType,bool isSigned)108 mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109 double rmax, bool narrowRange,
110 Type expressedType, bool isSigned) {
111 MLIRContext *ctx = expressedType.getContext();
112 unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113 Type storageType;
114 int64_t qmin;
115 int64_t qmax;
116 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117 qmin, qmax)) {
118 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
119 nullptr);
120 }
121
122 // Special case where min/max is close enough. The tensor contents are all
123 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124 // points and dequantized to 0.0.
125 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
126 return UniformQuantizedType::getChecked(flags, storageType, expressedType,
127 1.0, qmin, qmin, qmax, loc);
128 }
129
130 double scale;
131 int64_t nudgedZeroPoint;
132 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133
134 return UniformQuantizedType::getChecked(flags, storageType, expressedType,
135 scale, nudgedZeroPoint, qmin, qmax,
136 loc);
137 }
138
fakeQuantAttrsToType(Location loc,unsigned numBits,int32_t quantizedDimension,ArrayRef<double> rmins,ArrayRef<double> rmaxs,bool narrowRange,Type expressedType,bool isSigned)139 UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
140 Location loc, unsigned numBits, int32_t quantizedDimension,
141 ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142 Type expressedType, bool isSigned) {
143 size_t axis_size = rmins.size();
144 if (axis_size != rmaxs.size()) {
145 return (emitError(loc, "mismatched per-axis min and max size: ")
146 << axis_size << " vs. " << rmaxs.size(),
147 nullptr);
148 }
149
150 MLIRContext *ctx = expressedType.getContext();
151 Type storageType;
152 int64_t qmin;
153 int64_t qmax;
154 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155 qmin, qmax)) {
156 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
157 nullptr);
158 }
159
160 SmallVector<double, 4> scales;
161 SmallVector<int64_t, 4> zeroPoints;
162 scales.reserve(axis_size);
163 zeroPoints.reserve(axis_size);
164 for (size_t axis = 0; axis != axis_size; ++axis) {
165 double rmin = rmins[axis];
166 double rmax = rmaxs[axis];
167 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168 scales.push_back(1.0);
169 zeroPoints.push_back(qmin);
170 continue;
171 }
172
173 double scale;
174 int64_t nudgedZeroPoint;
175 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176 scales.push_back(scale);
177 zeroPoints.push_back(nudgedZeroPoint);
178 }
179
180 unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
181 return UniformQuantizedPerAxisType::getChecked(
182 flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
183 qmin, qmax, loc);
184 }
185