//===- QuantTypes.cpp - Quantization Type Implementation --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22
getFlags() const23 unsigned QuantizedType::getFlags() const {
24 return static_cast<ImplType *>(impl)->flags;
25 }
26
classof(Type type)27 bool QuantizedType::classof(Type type) {
28 return llvm::isa<QuantizationDialect>(type.getDialect());
29 }
30
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)31 LogicalResult QuantizedType::verifyConstructionInvariants(
32 Location loc, unsigned flags, Type storageType, Type expressedType,
33 int64_t storageTypeMin, int64_t storageTypeMax) {
34 // Verify that the storage type is integral.
35 // This restriction may be lifted at some point in favor of using bf16
36 // or f16 as exact representations on hardware where that is advantageous.
37 auto intStorageType = storageType.dyn_cast<IntegerType>();
38 if (!intStorageType)
39 return emitError(loc, "storage type must be integral");
40 unsigned integralWidth = intStorageType.getWidth();
41
42 // Verify storage width.
43 if (integralWidth == 0 || integralWidth > MaxStorageBits)
44 return emitError(loc, "illegal storage type size: ") << integralWidth;
45
46 // Verify storageTypeMin and storageTypeMax.
47 bool isSigned =
48 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
49 int64_t defaultIntegerMin =
50 getDefaultMinimumForInteger(isSigned, integralWidth);
51 int64_t defaultIntegerMax =
52 getDefaultMaximumForInteger(isSigned, integralWidth);
53 if (storageTypeMax - storageTypeMin <= 0 ||
54 storageTypeMin < defaultIntegerMin ||
55 storageTypeMax > defaultIntegerMax) {
56 return emitError(loc, "illegal storage min and storage max: (")
57 << storageTypeMin << ":" << storageTypeMax << ")";
58 }
59 return success();
60 }
61
getStorageType() const62 Type QuantizedType::getStorageType() const {
63 return static_cast<ImplType *>(impl)->storageType;
64 }
65
getStorageTypeMin() const66 int64_t QuantizedType::getStorageTypeMin() const {
67 return static_cast<ImplType *>(impl)->storageTypeMin;
68 }
69
getStorageTypeMax() const70 int64_t QuantizedType::getStorageTypeMax() const {
71 return static_cast<ImplType *>(impl)->storageTypeMax;
72 }
73
getStorageTypeIntegralWidth() const74 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
75 // NOTE: If ever supporting non-integral storage types, some other scheme
76 // for determining the width will be needed.
77 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
78 }
79
getExpressedType() const80 Type QuantizedType::getExpressedType() const {
81 return static_cast<ImplType *>(impl)->expressedType;
82 }
83
isCompatibleExpressedType(Type candidateExpressedType)84 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
85 if (candidateExpressedType.isa<ShapedType>()) {
86 return candidateExpressedType.cast<ShapedType>().getElementType() ==
87 getExpressedType();
88 }
89 return candidateExpressedType == getExpressedType();
90 }
91
92 QuantizedType
getQuantizedElementType(Type primitiveOrContainerType)93 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
94 if (primitiveOrContainerType.isa<ShapedType>()) {
95 Type elementType =
96 primitiveOrContainerType.cast<ShapedType>().getElementType();
97 return elementType.dyn_cast<QuantizedType>();
98 }
99 return primitiveOrContainerType.dyn_cast<QuantizedType>();
100 }
101
castFromStorageType(Type candidateType)102 Type QuantizedType::castFromStorageType(Type candidateType) {
103 if (candidateType == getStorageType()) {
104 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
105 return *this;
106 } else if (candidateType.isa<RankedTensorType>()) {
107 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
108 return RankedTensorType::get(
109 candidateType.cast<RankedTensorType>().getShape(), getStorageType());
110 } else if (candidateType.isa<UnrankedTensorType>()) {
111 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
112 return UnrankedTensorType::get(getStorageType());
113 } else if (candidateType.isa<VectorType>()) {
114 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
115 return VectorType::get(candidateType.cast<VectorType>().getShape(),
116 getStorageType());
117 }
118
119 return nullptr;
120 }
121
castToStorageType(Type quantizedType)122 Type QuantizedType::castToStorageType(Type quantizedType) {
123 if (quantizedType.isa<QuantizedType>()) {
124 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
125 return quantizedType.cast<QuantizedType>().getStorageType();
126 } else if (quantizedType.isa<ShapedType>()) {
127 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
128 ShapedType sType = quantizedType.cast<ShapedType>();
129 if (!sType.getElementType().isa<QuantizedType>()) {
130 return nullptr;
131 }
132 Type storageType =
133 sType.getElementType().cast<QuantizedType>().getStorageType();
134 if (quantizedType.isa<RankedTensorType>()) {
135 return RankedTensorType::get(sType.getShape(), storageType);
136 } else if (quantizedType.isa<UnrankedTensorType>()) {
137 return UnrankedTensorType::get(storageType);
138 } else if (quantizedType.isa<VectorType>()) {
139 return VectorType::get(sType.getShape(), storageType);
140 }
141 }
142
143 return nullptr;
144 }
145
castFromExpressedType(Type candidateType)146 Type QuantizedType::castFromExpressedType(Type candidateType) {
147 if (candidateType == getExpressedType()) {
148 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
149 return *this;
150 } else if (candidateType.isa<ShapedType>()) {
151 ShapedType candidateShapedType = candidateType.cast<ShapedType>();
152 if (candidateShapedType.getElementType() != getExpressedType()) {
153 return nullptr;
154 }
155
156 if (candidateType.isa<RankedTensorType>()) {
157 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
158 return RankedTensorType::get(candidateShapedType.getShape(), *this);
159 } else if (candidateType.isa<UnrankedTensorType>()) {
160 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
161 return UnrankedTensorType::get(*this);
162 } else if (candidateType.isa<VectorType>()) {
163 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
164 return VectorType::get(candidateShapedType.getShape(), *this);
165 }
166 }
167
168 return nullptr;
169 }
170
castToExpressedType(Type quantizedType)171 Type QuantizedType::castToExpressedType(Type quantizedType) {
172 if (quantizedType.isa<QuantizedType>()) {
173 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
174 return quantizedType.cast<QuantizedType>().getExpressedType();
175 } else if (quantizedType.isa<ShapedType>()) {
176 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
177 ShapedType sType = quantizedType.cast<ShapedType>();
178 if (!sType.getElementType().isa<QuantizedType>()) {
179 return nullptr;
180 }
181 Type expressedType =
182 sType.getElementType().cast<QuantizedType>().getExpressedType();
183 if (quantizedType.isa<RankedTensorType>()) {
184 return RankedTensorType::get(sType.getShape(), expressedType);
185 } else if (quantizedType.isa<UnrankedTensorType>()) {
186 return UnrankedTensorType::get(expressedType);
187 } else if (quantizedType.isa<VectorType>()) {
188 return VectorType::get(sType.getShape(), expressedType);
189 }
190 }
191
192 return nullptr;
193 }
194
castExpressedToStorageType(Type candidateType)195 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
196 Type expressedQuantizedType = castFromExpressedType(candidateType);
197 if (!expressedQuantizedType) {
198 return nullptr;
199 }
200 return QuantizedType::castToStorageType(expressedQuantizedType);
201 }
202
get(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)203 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
204 Type expressedType,
205 int64_t storageTypeMin,
206 int64_t storageTypeMax) {
207 return Base::get(storageType.getContext(), flags, storageType, expressedType,
208 storageTypeMin, storageTypeMax);
209 }
210
getChecked(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax,Location location)211 AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
212 Type expressedType,
213 int64_t storageTypeMin,
214 int64_t storageTypeMax,
215 Location location) {
216 return Base::getChecked(location, flags, storageType, expressedType,
217 storageTypeMin, storageTypeMax);
218 }
219
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)220 LogicalResult AnyQuantizedType::verifyConstructionInvariants(
221 Location loc, unsigned flags, Type storageType, Type expressedType,
222 int64_t storageTypeMin, int64_t storageTypeMax) {
223 if (failed(QuantizedType::verifyConstructionInvariants(
224 loc, flags, storageType, expressedType, storageTypeMin,
225 storageTypeMax))) {
226 return failure();
227 }
228
229 // Verify that the expressed type is floating point.
230 // If this restriction is ever eliminated, the parser/printer must be
231 // extended.
232 if (expressedType && !expressedType.isa<FloatType>())
233 return emitError(loc, "expressed type must be floating point");
234
235 return success();
236 }
237
get(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)238 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
239 Type expressedType, double scale,
240 int64_t zeroPoint,
241 int64_t storageTypeMin,
242 int64_t storageTypeMax) {
243 return Base::get(storageType.getContext(), flags, storageType, expressedType,
244 scale, zeroPoint, storageTypeMin, storageTypeMax);
245 }
246
247 UniformQuantizedType
getChecked(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax,Location location)248 UniformQuantizedType::getChecked(unsigned flags, Type storageType,
249 Type expressedType, double scale,
250 int64_t zeroPoint, int64_t storageTypeMin,
251 int64_t storageTypeMax, Location location) {
252 return Base::getChecked(location, flags, storageType, expressedType, scale,
253 zeroPoint, storageTypeMin, storageTypeMax);
254 }
255
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)256 LogicalResult UniformQuantizedType::verifyConstructionInvariants(
257 Location loc, unsigned flags, Type storageType, Type expressedType,
258 double scale, int64_t zeroPoint, int64_t storageTypeMin,
259 int64_t storageTypeMax) {
260 if (failed(QuantizedType::verifyConstructionInvariants(
261 loc, flags, storageType, expressedType, storageTypeMin,
262 storageTypeMax))) {
263 return failure();
264 }
265
266 // Uniform quantization requires fully expressed parameters, including
267 // expressed type.
268 if (!expressedType)
269 return emitError(loc, "uniform quantization requires expressed type");
270
271 // Verify that the expressed type is floating point.
272 // If this restriction is ever eliminated, the parser/printer must be
273 // extended.
274 if (!expressedType.isa<FloatType>())
275 return emitError(loc, "expressed type must be floating point");
276
277 // Verify scale.
278 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
279 return emitError(loc, "illegal scale: ") << scale;
280
281 return success();
282 }
283
getScale() const284 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
285
getZeroPoint() const286 int64_t UniformQuantizedType::getZeroPoint() const {
287 return getImpl()->zeroPoint;
288 }
289
get(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)290 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
291 unsigned flags, Type storageType, Type expressedType,
292 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
293 int32_t quantizedDimension, int64_t storageTypeMin,
294 int64_t storageTypeMax) {
295 return Base::get(storageType.getContext(), flags, storageType, expressedType,
296 scales, zeroPoints, quantizedDimension, storageTypeMin,
297 storageTypeMax);
298 }
299
getChecked(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax,Location location)300 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
301 unsigned flags, Type storageType, Type expressedType,
302 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
303 int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
304 Location location) {
305 return Base::getChecked(location, flags, storageType, expressedType, scales,
306 zeroPoints, quantizedDimension, storageTypeMin,
307 storageTypeMax);
308 }
309
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)310 LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
311 Location loc, unsigned flags, Type storageType, Type expressedType,
312 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
313 int32_t quantizedDimension, int64_t storageTypeMin,
314 int64_t storageTypeMax) {
315 if (failed(QuantizedType::verifyConstructionInvariants(
316 loc, flags, storageType, expressedType, storageTypeMin,
317 storageTypeMax))) {
318 return failure();
319 }
320
321 // Uniform quantization requires fully expressed parameters, including
322 // expressed type.
323 if (!expressedType)
324 return emitError(loc, "uniform quantization requires expressed type");
325
326 // Verify that the expressed type is floating point.
327 // If this restriction is ever eliminated, the parser/printer must be
328 // extended.
329 if (!expressedType.isa<FloatType>())
330 return emitError(loc, "expressed type must be floating point");
331
332 // Ensure that the number of scales and zeroPoints match.
333 if (scales.size() != zeroPoints.size())
334 return emitError(loc, "illegal number of scales and zeroPoints: ")
335 << scales.size() << ", " << zeroPoints.size();
336
337 // Verify scale.
338 for (double scale : scales) {
339 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
340 return emitError(loc, "illegal scale: ") << scale;
341 }
342
343 return success();
344 }
345
getScales() const346 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
347 return getImpl()->getScales();
348 }
349
getZeroPoints() const350 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
351 return getImpl()->getZeroPoints();
352 }
353
getQuantizedDimension() const354 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
355 return getImpl()->quantizedDimension;
356 }
357
get(Type expressedType,double min,double max)358 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
359 double min, double max) {
360 return Base::get(expressedType.getContext(), expressedType, min, max);
361 }
362
getChecked(Type expressedType,double min,double max,Location location)363 CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType,
364 double min,
365 double max,
366 Location location) {
367 return Base::getChecked(location, expressedType, min, max);
368 }
369
verifyConstructionInvariants(Location loc,Type expressedType,double min,double max)370 LogicalResult CalibratedQuantizedType::verifyConstructionInvariants(
371 Location loc, Type expressedType, double min, double max) {
372 // Verify that the expressed type is floating point.
373 // If this restriction is ever eliminated, the parser/printer must be
374 // extended.
375 if (!expressedType.isa<FloatType>())
376 return emitError(loc, "expressed type must be floating point");
377 if (max <= min)
378 return emitError(loc, "illegal min and max: (") << min << ":" << max << ")";
379
380 return success();
381 }
382
getMin() const383 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
384
getMax() const385 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
386