• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- QuantUtils.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains TOSA numerical support functions and quantization
10 // attribute builders.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
/// From a scale value, generates a 16-bit multiplier and a right-shift amount
/// such that scale ~= multiplier * 2^-shift. The multiplier is the mantissa of
/// `scale` (in [-1.0, -0.5] or [0.5, 1.0]) expressed with 15 fractional bits.
///
/// TOSA RESCALE only accepts shifts up to 62, so for very small scales the
/// excess shift is folded into the multiplier (which may underflow to 0).
/// NOTE(review): large scales (roughly 2^13 and above) produce shift < 2,
/// which the TOSA spec also disallows; that case is not diagnosed here.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kMaxShift = 62;

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding can push the mantissa up to exactly 1.0 (1 << 15), which does not
  // fit a signed 16-bit value; renormalize to 0.5 * 2^(shift + 1).
  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifts beyond kMaxShift are not representable in TOSA: move the excess
  // into the multiplier instead. A shift of 31 already reduces any int32_t
  // multiplier to 0 (or -1 if negative), and shifting an int32_t by 32 or
  // more is undefined, so cap that shift at 31.
  if (shift > kMaxShift) {
    const int32_t excess = shift - kMaxShift;
    multiplier = multiplier >> (excess < 31 ? excess : 31);
    shift = kMaxShift;
  }
}
47 
/// From a scale value, generates a 32-bit multiplier and a right-shift amount
/// such that scale ~= multiplier * 2^-shift. The multiplier is the mantissa of
/// `scale` (in [-1.0, -0.5] or [0.5, 1.0]) expressed with 31 fractional bits.
///
/// TOSA RESCALE only accepts shifts up to 62, so for very small scales the
/// excess shift is folded into the multiplier (which may underflow to 0).
/// NOTE(review): large scales (roughly 2^29 and above) produce shift < 2,
/// which the TOSA spec also disallows; that case is not diagnosed here.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kMaxShift = 62;

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding can push the mantissa up to exactly 1.0 (1 << 31), which does not
  // fit a signed 32-bit value; renormalize to 0.5 * 2^(shift + 1).
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifts beyond kMaxShift are not representable in TOSA: move the excess
  // into the multiplier instead. A shift of 31 already reduces any int32_t
  // multiplier to 0 (or -1 if negative), and shifting an int32_t by 32 or
  // more is undefined, so cap that shift at 31.
  if (shift > kMaxShift) {
    const int32_t excess = shift - kMaxShift;
    multiplier = multiplier >> (excess < 31 ? excess : 31);
    shift = kMaxShift;
  }
}
75 
76 /// Generates a quantized multiplier/shift from double.
computeMultiplierAndShift(double scale,int32_t & multiplier,int32_t & shift,int32_t scaleWidth)77 void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
78                                            int32_t &shift, int32_t scaleWidth) {
79 
80   switch (scaleWidth) {
81   case 16:
82     computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
83     return;
84   case 32:
85     computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
86     return;
87   default:
88     assert(0 && "Unsupported Tosa quantized_scale regime specified!");
89   }
90 }
91 
// Evaluates to the element type of `input_type` (anything with a
// getElementType() method, e.g. RankedTensorType) dyn_cast to
// quant::UniformQuantizedType; the result is a null type when the element
// type is not uniformly quantized.
92 #define GET_UQTYPE(input_type)                                                 \
93   ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
// Same as GET_UQTYPE, but accepts any quant::QuantizedType (e.g. uniform or
// per-axis); null when the element type is not quantized at all.
94 #define GET_QTYPE(input_type)                                                  \
95   ((input_type).getElementType().dyn_cast<quant::QuantizedType>())
96 
97 /// Method to build ConvOpQuantizationAttr, called from
98 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
99 /// input_zp: input zeropoint
100 /// weight_zp: weight zeropoint.
101 ConvOpQuantizationAttr
buildConvOpQuantizationAttr(OpBuilder & builder,Value input,Value weight)102 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
103                                         Value weight) {
104 
105   auto inputType = input.getType().dyn_cast<RankedTensorType>();
106   auto weightType = weight.getType().dyn_cast<RankedTensorType>();
107 
108   if (!inputType || !weightType)
109     return nullptr;
110 
111   auto inputQType = GET_UQTYPE(inputType);
112   auto weightPerTensorQType = GET_UQTYPE(weightType);
113   auto weightPerAxisQType = weightType.getElementType()
114                                 .dyn_cast<quant::UniformQuantizedPerAxisType>();
115 
116   // Weights must be either per-tensor quantized or per-axis quantized.
117   assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
118          "Weights must be either per-tensor or per-axis quantized");
119 
120   // Either all quantized or all not quantized.
121   assert(!((bool)inputQType ^
122            ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
123          "Inputs and weights must be all quantized or all not quantized");
124 
125   if (inputQType) {
126 
127     int64_t inputZp = inputQType.getZeroPoint();
128     int64_t weightZp = 0;
129 
130     if (weightPerTensorQType) {
131       weightZp = weightPerTensorQType.getZeroPoint();
132     } else if (weightPerAxisQType) {
133       weightZp = weightPerAxisQType.getZeroPoints().front();
134     }
135 
136     auto quantAttr = tosa::ConvOpQuantizationAttr::get(
137         builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp),
138         builder.getContext());
139 
140     return quantAttr;
141   }
142 
143   return nullptr;
144 }
145 
146 /// Builds MatMulOpQuantizationAttr, called from
147 /// MatMulOpQuantInfoBuilder:
148 /// aZp: input a zeropoint
149 /// bZp: input b zeropoint.
150 MatMulOpQuantizationAttr
buildMatMulOpQuantizationAttr(OpBuilder & builder,Value a,Value b)151 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
152                                           Value b) {
153 
154   auto aType = a.getType().dyn_cast<RankedTensorType>();
155   auto bType = b.getType().dyn_cast<RankedTensorType>();
156 
157   if (!aType || !bType)
158     return nullptr;
159 
160   auto aQType = GET_UQTYPE(aType);
161   auto bQType = GET_UQTYPE(bType);
162 
163   // A and B are either all quantized or all not quantized.
164   assert(!((bool)aQType ^ (bool)bQType) &&
165          "Matmul operands must be all quantized or all not quantized");
166 
167   if (aQType) {
168 
169     int64_t aZp = aQType.getZeroPoint();
170     int64_t bZp = bQType.getZeroPoint();
171 
172     auto quantAttr = tosa::MatMulOpQuantizationAttr::get(
173         builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp),
174         builder.getContext());
175 
176     return quantAttr;
177   }
178 
179   return nullptr;
180 }
181 
182 /// Builds UnaryOpQuantizationAttr
183 /// UnaryOpQuantInfoBuilder:
184 /// inputZp: input zeropoint
185 /// outputZp: output zeropoint.
186 UnaryOpQuantizationAttr
buildUnaryOpQuantizationAttr(OpBuilder & builder,Value input,Type outputRawType)187 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
188                                          Type outputRawType) {
189 
190   auto inputType = input.getType().dyn_cast<RankedTensorType>();
191   auto outputType = outputRawType.dyn_cast<RankedTensorType>();
192 
193   if (!inputType || !outputType)
194     return nullptr;
195 
196   auto inputQType = GET_UQTYPE(inputType);
197   auto outputQType = GET_UQTYPE(outputType);
198 
199   // Either all quantized or all not quantized.
200   assert(!((bool)inputQType ^ (bool)outputQType) &&
201          "Unary inputs/outputs must be all quantized or all not quantized");
202 
203   if (inputQType) {
204 
205     int64_t inputZp = inputQType.getZeroPoint();
206     int64_t outputZp = outputQType.getZeroPoint();
207 
208     auto quantAttr = tosa::UnaryOpQuantizationAttr::get(
209         builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp),
210         builder.getContext());
211 
212     return quantAttr;
213   }
214 
215   return nullptr;
216 }
217 
218 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
219 /// inputZp: input zeropoint.
buildPadOpQuantizationAttr(OpBuilder & builder,Value input)220 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
221                                                              Value input) {
222 
223   auto inputType = input.getType().dyn_cast<RankedTensorType>();
224 
225   if (!inputType)
226     return nullptr;
227 
228   auto inputQType = GET_UQTYPE(inputType);
229 
230   if (inputQType) {
231 
232     int64_t inputZp = inputQType.getZeroPoint();
233 
234     auto quantAttr = tosa::PadOpQuantizationAttr::get(
235         builder.getI32IntegerAttr(inputZp), builder.getContext());
236 
237     return quantAttr;
238   }
239 
240   return nullptr;
241 }
242 
243 /// Builds output type for a quantized ConvOp with the right bitwidth.
244 /// This is called by the builder when dealing with quantized content.
buildConvOpResultTypeInfo(OpBuilder & builder,Type outputType,Value input,Value weight)245 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
246                                            Value input, Value weight) {
247 
248   auto inputType = input.getType().dyn_cast<RankedTensorType>();
249   auto weightType = weight.getType().dyn_cast<RankedTensorType>();
250 
251   assert(inputType && weightType &&
252          "Could not extract input or weight tensors from Conv op");
253 
254   auto inputQType = GET_QTYPE(inputType);
255   auto weightQType = GET_QTYPE(weightType);
256 
257   assert(inputQType && weightQType &&
258          "Could not extract input or weight tensor types from Conv op");
259 
260   unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
261   unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
262 
263   auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
264   assert(outputShapedType &&
265          "Could not extract output shape type from Conv op");
266 
267   auto outputShape = outputShapedType.getShape();
268 
269   IntegerType accElementType;
270   if (inputBits == 16 && weightBits == 8)
271     accElementType = builder.getIntegerType(48);
272   else
273     accElementType = builder.getI32Type();
274   auto accType = RankedTensorType::get(outputShape, accElementType);
275   return accType;
276 }
277 
278 /// Builds Tosa quantization attributes from min/max values.
buildQTypeFromMinMax(OpBuilder builder,Type inputDType,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)279 Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
280                                       Attribute minAttr, Attribute maxAttr,
281                                       IntegerAttr quantBits, int filterQuantDim,
282                                       bool isSigned, BoolAttr narrowRange) {
283 
284   quant::QuantizedType retType;
285 
286   auto convfunc =
287       quant::ExpressedToQuantizedConverter::forInputType(inputDType);
288 
289   auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
290   auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
291 
292   SmallVector<double, 2> min, max;
293 
294   // At least one is per-axis quantized elementsattr.
295   if (minElems || maxElems) {
296     // Must have the same number of elements.
297     if (minElems.getNumElements() != maxElems.getNumElements())
298       return {};
299     min.reserve(minElems.getNumElements());
300     max.reserve(maxElems.getNumElements());
301     for (auto i : minElems)
302       min.push_back(FloatAttr::getValueAsDouble(i));
303     for (auto i : maxElems)
304       max.push_back(FloatAttr::getValueAsDouble(i));
305   } else { // Just a single FP value.
306     auto minVal = minAttr.dyn_cast<FloatAttr>();
307     if (minVal)
308       min.push_back(minVal.getValueAsDouble());
309     else
310       return {};
311     auto maxVal = maxAttr.dyn_cast<FloatAttr>();
312     if (maxVal)
313       max.push_back(maxVal.getValueAsDouble());
314     else
315       return {};
316   }
317 
318   if (min.size() == max.size()) {
319     if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
320       retType = quant::fakeQuantAttrsToType(
321           builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
322           narrowRange.getValue(), convfunc.expressedType, isSigned);
323     } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
324       auto shape = inputDType.dyn_cast<ShapedType>();
325       if (!shape)
326         return {};
327       if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
328         retType = quant::fakeQuantAttrsToType(
329             builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
330             max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
331       }
332     } else {
333       return {};
334     }
335   } else {
336     return {};
337   }
338 
339   if (!retType)
340     return {};
341 
342   return convfunc.convert(retType);
343 }
344 
345 /// Builds Tosa quantization attributes from min/max values.
346 TypeAttr
buildQTypeAttrFromMinMax(OpBuilder builder,Type inputDtype,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)347 mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
348                                      Attribute minAttr, Attribute maxAttr,
349                                      IntegerAttr quantBits, int filterQuantDim,
350                                      bool isSigned, BoolAttr narrowRange) {
351 
352   return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
353                                             maxAttr, quantBits, filterQuantDim,
354                                             isSigned, narrowRange));
355 }
356