1 //===- QuantUtils.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains TOSA numerical support functions and quantization
10 // attribute builders.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
15
16 using namespace mlir;
17 using namespace mlir::tosa;
18
/// From a scale value, generates multiplier and shift values where
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
/// multiplier = mantissa*2^shift for 16-bit scaling.
///
/// On return, `scale` ~= multiplier * 2^(-shift): `multiplier` holds the
/// frexp mantissa scaled by 2^15 and `shift` is a (positive) right-shift
/// amount. The shift is capped at 62; when the natural shift would be larger,
/// the multiplier is right-shifted by the excess (at most 15 bits) instead,
/// trading low-order precision for a representable shift.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding can push the mantissa up to exactly 1.0; renormalize so the
  // multiplier stays within 15 fractional bits.
  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // TOSA limits the rescale shift to 62 (so it can be decomposed into two
  // right shifts of 31). For very small scales, fold the excess shift into
  // the multiplier instead. Shifting a 16-bit multiplier by more than 15 bits
  // is unnecessary: its magnitude is already exhausted at that point.
  if (shift > 62) {
    const int32_t excess = shift - 62;
    multiplier >>= (excess < 15 ? excess : 15);
    shift = 62;
  }
}
47
/// From a scale value, generates multiplier and shift values where
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
/// multiplier = mantissa*2^shift for 32-bit scaling.
///
/// On return, `scale` ~= multiplier * 2^(-shift): `multiplier` holds the
/// frexp mantissa scaled by 2^31 and `shift` is a (positive) right-shift
/// amount. The shift is capped at 62; when the natural shift would be larger,
/// the multiplier is right-shifted by the excess (at most 31 bits) instead,
/// trading low-order precision for a representable shift.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding can push the mantissa up to exactly 1.0, which would not fit in
  // int32_t; renormalize so the multiplier stays within 31 fractional bits.
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // TOSA limits the rescale shift to 62 (so it can be decomposed into two
  // right shifts of 31). For very small scales, fold the excess shift into
  // the multiplier instead. Shifting a 32-bit multiplier by more than 31 bits
  // is unnecessary (and would be undefined for int32_t).
  if (shift > 62) {
    const int32_t excess = shift - 62;
    multiplier >>= (excess < 31 ? excess : 31);
    shift = 62;
  }
}
75
76 /// Generates a quantized multiplier/shift from double.
computeMultiplierAndShift(double scale,int32_t & multiplier,int32_t & shift,int32_t scaleWidth)77 void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
78 int32_t &shift, int32_t scaleWidth) {
79
80 switch (scaleWidth) {
81 case 16:
82 computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
83 return;
84 case 32:
85 computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
86 return;
87 default:
88 assert(0 && "Unsupported Tosa quantized_scale regime specified!");
89 }
90 }
91
// Extract the quantized element type from a shaped type. Each macro evaluates
// to the result of `dyn_cast` on the element type, i.e. a null type when the
// element type is not of the requested quantized kind.
#define GET_UQTYPE(shapedTy)                                                   \
  ((shapedTy).getElementType().dyn_cast<quant::UniformQuantizedType>())
#define GET_QTYPE(shapedTy)                                                    \
  ((shapedTy).getElementType().dyn_cast<quant::QuantizedType>())
96
97 /// Method to build ConvOpQuantizationAttr, called from
98 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
99 /// input_zp: input zeropoint
100 /// weight_zp: weight zeropoint.
101 ConvOpQuantizationAttr
buildConvOpQuantizationAttr(OpBuilder & builder,Value input,Value weight)102 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
103 Value weight) {
104
105 auto inputType = input.getType().dyn_cast<RankedTensorType>();
106 auto weightType = weight.getType().dyn_cast<RankedTensorType>();
107
108 if (!inputType || !weightType)
109 return nullptr;
110
111 auto inputQType = GET_UQTYPE(inputType);
112 auto weightPerTensorQType = GET_UQTYPE(weightType);
113 auto weightPerAxisQType = weightType.getElementType()
114 .dyn_cast<quant::UniformQuantizedPerAxisType>();
115
116 // Weights must be either per-tensor quantized or per-axis quantized.
117 assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
118 "Weights must be either per-tensor or per-axis quantized");
119
120 // Either all quantized or all not quantized.
121 assert(!((bool)inputQType ^
122 ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
123 "Inputs and weights must be all quantized or all not quantized");
124
125 if (inputQType) {
126
127 int64_t inputZp = inputQType.getZeroPoint();
128 int64_t weightZp = 0;
129
130 if (weightPerTensorQType) {
131 weightZp = weightPerTensorQType.getZeroPoint();
132 } else if (weightPerAxisQType) {
133 weightZp = weightPerAxisQType.getZeroPoints().front();
134 }
135
136 auto quantAttr = tosa::ConvOpQuantizationAttr::get(
137 builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp),
138 builder.getContext());
139
140 return quantAttr;
141 }
142
143 return nullptr;
144 }
145
146 /// Builds MatMulOpQuantizationAttr, called from
147 /// MatMulOpQuantInfoBuilder:
148 /// aZp: input a zeropoint
149 /// bZp: input b zeropoint.
150 MatMulOpQuantizationAttr
buildMatMulOpQuantizationAttr(OpBuilder & builder,Value a,Value b)151 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
152 Value b) {
153
154 auto aType = a.getType().dyn_cast<RankedTensorType>();
155 auto bType = b.getType().dyn_cast<RankedTensorType>();
156
157 if (!aType || !bType)
158 return nullptr;
159
160 auto aQType = GET_UQTYPE(aType);
161 auto bQType = GET_UQTYPE(bType);
162
163 // A and B are either all quantized or all not quantized.
164 assert(!((bool)aQType ^ (bool)bQType) &&
165 "Matmul operands must be all quantized or all not quantized");
166
167 if (aQType) {
168
169 int64_t aZp = aQType.getZeroPoint();
170 int64_t bZp = bQType.getZeroPoint();
171
172 auto quantAttr = tosa::MatMulOpQuantizationAttr::get(
173 builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp),
174 builder.getContext());
175
176 return quantAttr;
177 }
178
179 return nullptr;
180 }
181
182 /// Builds UnaryOpQuantizationAttr
183 /// UnaryOpQuantInfoBuilder:
184 /// inputZp: input zeropoint
185 /// outputZp: output zeropoint.
186 UnaryOpQuantizationAttr
buildUnaryOpQuantizationAttr(OpBuilder & builder,Value input,Type outputRawType)187 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
188 Type outputRawType) {
189
190 auto inputType = input.getType().dyn_cast<RankedTensorType>();
191 auto outputType = outputRawType.dyn_cast<RankedTensorType>();
192
193 if (!inputType || !outputType)
194 return nullptr;
195
196 auto inputQType = GET_UQTYPE(inputType);
197 auto outputQType = GET_UQTYPE(outputType);
198
199 // Either all quantized or all not quantized.
200 assert(!((bool)inputQType ^ (bool)outputQType) &&
201 "Unary inputs/outputs must be all quantized or all not quantized");
202
203 if (inputQType) {
204
205 int64_t inputZp = inputQType.getZeroPoint();
206 int64_t outputZp = outputQType.getZeroPoint();
207
208 auto quantAttr = tosa::UnaryOpQuantizationAttr::get(
209 builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp),
210 builder.getContext());
211
212 return quantAttr;
213 }
214
215 return nullptr;
216 }
217
218 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
219 /// inputZp: input zeropoint.
buildPadOpQuantizationAttr(OpBuilder & builder,Value input)220 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
221 Value input) {
222
223 auto inputType = input.getType().dyn_cast<RankedTensorType>();
224
225 if (!inputType)
226 return nullptr;
227
228 auto inputQType = GET_UQTYPE(inputType);
229
230 if (inputQType) {
231
232 int64_t inputZp = inputQType.getZeroPoint();
233
234 auto quantAttr = tosa::PadOpQuantizationAttr::get(
235 builder.getI32IntegerAttr(inputZp), builder.getContext());
236
237 return quantAttr;
238 }
239
240 return nullptr;
241 }
242
243 /// Builds output type for a quantized ConvOp with the right bitwidth.
244 /// This is called by the builder when dealing with quantized content.
buildConvOpResultTypeInfo(OpBuilder & builder,Type outputType,Value input,Value weight)245 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
246 Value input, Value weight) {
247
248 auto inputType = input.getType().dyn_cast<RankedTensorType>();
249 auto weightType = weight.getType().dyn_cast<RankedTensorType>();
250
251 assert(inputType && weightType &&
252 "Could not extract input or weight tensors from Conv op");
253
254 auto inputQType = GET_QTYPE(inputType);
255 auto weightQType = GET_QTYPE(weightType);
256
257 assert(inputQType && weightQType &&
258 "Could not extract input or weight tensor types from Conv op");
259
260 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
261 unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
262
263 auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
264 assert(outputShapedType &&
265 "Could not extract output shape type from Conv op");
266
267 auto outputShape = outputShapedType.getShape();
268
269 IntegerType accElementType;
270 if (inputBits == 16 && weightBits == 8)
271 accElementType = builder.getIntegerType(48);
272 else
273 accElementType = builder.getI32Type();
274 auto accType = RankedTensorType::get(outputShape, accElementType);
275 return accType;
276 }
277
278 /// Builds Tosa quantization attributes from min/max values.
buildQTypeFromMinMax(OpBuilder builder,Type inputDType,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)279 Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
280 Attribute minAttr, Attribute maxAttr,
281 IntegerAttr quantBits, int filterQuantDim,
282 bool isSigned, BoolAttr narrowRange) {
283
284 quant::QuantizedType retType;
285
286 auto convfunc =
287 quant::ExpressedToQuantizedConverter::forInputType(inputDType);
288
289 auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
290 auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
291
292 SmallVector<double, 2> min, max;
293
294 // At least one is per-axis quantized elementsattr.
295 if (minElems || maxElems) {
296 // Must have the same number of elements.
297 if (minElems.getNumElements() != maxElems.getNumElements())
298 return {};
299 min.reserve(minElems.getNumElements());
300 max.reserve(maxElems.getNumElements());
301 for (auto i : minElems)
302 min.push_back(FloatAttr::getValueAsDouble(i));
303 for (auto i : maxElems)
304 max.push_back(FloatAttr::getValueAsDouble(i));
305 } else { // Just a single FP value.
306 auto minVal = minAttr.dyn_cast<FloatAttr>();
307 if (minVal)
308 min.push_back(minVal.getValueAsDouble());
309 else
310 return {};
311 auto maxVal = maxAttr.dyn_cast<FloatAttr>();
312 if (maxVal)
313 max.push_back(maxVal.getValueAsDouble());
314 else
315 return {};
316 }
317
318 if (min.size() == max.size()) {
319 if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
320 retType = quant::fakeQuantAttrsToType(
321 builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
322 narrowRange.getValue(), convfunc.expressedType, isSigned);
323 } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
324 auto shape = inputDType.dyn_cast<ShapedType>();
325 if (!shape)
326 return {};
327 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
328 retType = quant::fakeQuantAttrsToType(
329 builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
330 max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
331 }
332 } else {
333 return {};
334 }
335 } else {
336 return {};
337 }
338
339 if (!retType)
340 return {};
341
342 return convfunc.convert(retType);
343 }
344
345 /// Builds Tosa quantization attributes from min/max values.
346 TypeAttr
buildQTypeAttrFromMinMax(OpBuilder builder,Type inputDtype,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)347 mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
348 Attribute minAttr, Attribute maxAttr,
349 IntegerAttr quantBits, int filterQuantDim,
350 bool isSigned, BoolAttr narrowRange) {
351
352 return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
353 maxAttr, quantBits, filterQuantDim,
354 isSigned, narrowRange));
355 }
356