• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22 
getFlags() const23 unsigned QuantizedType::getFlags() const {
24   return static_cast<ImplType *>(impl)->flags;
25 }
26 
classof(Type type)27 bool QuantizedType::classof(Type type) {
28   return llvm::isa<QuantizationDialect>(type.getDialect());
29 }
30 
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)31 LogicalResult QuantizedType::verifyConstructionInvariants(
32     Location loc, unsigned flags, Type storageType, Type expressedType,
33     int64_t storageTypeMin, int64_t storageTypeMax) {
34   // Verify that the storage type is integral.
35   // This restriction may be lifted at some point in favor of using bf16
36   // or f16 as exact representations on hardware where that is advantageous.
37   auto intStorageType = storageType.dyn_cast<IntegerType>();
38   if (!intStorageType)
39     return emitError(loc, "storage type must be integral");
40   unsigned integralWidth = intStorageType.getWidth();
41 
42   // Verify storage width.
43   if (integralWidth == 0 || integralWidth > MaxStorageBits)
44     return emitError(loc, "illegal storage type size: ") << integralWidth;
45 
46   // Verify storageTypeMin and storageTypeMax.
47   bool isSigned =
48       (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
49   int64_t defaultIntegerMin =
50       getDefaultMinimumForInteger(isSigned, integralWidth);
51   int64_t defaultIntegerMax =
52       getDefaultMaximumForInteger(isSigned, integralWidth);
53   if (storageTypeMax - storageTypeMin <= 0 ||
54       storageTypeMin < defaultIntegerMin ||
55       storageTypeMax > defaultIntegerMax) {
56     return emitError(loc, "illegal storage min and storage max: (")
57            << storageTypeMin << ":" << storageTypeMax << ")";
58   }
59   return success();
60 }
61 
getStorageType() const62 Type QuantizedType::getStorageType() const {
63   return static_cast<ImplType *>(impl)->storageType;
64 }
65 
getStorageTypeMin() const66 int64_t QuantizedType::getStorageTypeMin() const {
67   return static_cast<ImplType *>(impl)->storageTypeMin;
68 }
69 
getStorageTypeMax() const70 int64_t QuantizedType::getStorageTypeMax() const {
71   return static_cast<ImplType *>(impl)->storageTypeMax;
72 }
73 
getStorageTypeIntegralWidth() const74 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
75   // NOTE: If ever supporting non-integral storage types, some other scheme
76   // for determining the width will be needed.
77   return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
78 }
79 
getExpressedType() const80 Type QuantizedType::getExpressedType() const {
81   return static_cast<ImplType *>(impl)->expressedType;
82 }
83 
isCompatibleExpressedType(Type candidateExpressedType)84 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
85   if (candidateExpressedType.isa<ShapedType>()) {
86     return candidateExpressedType.cast<ShapedType>().getElementType() ==
87            getExpressedType();
88   }
89   return candidateExpressedType == getExpressedType();
90 }
91 
92 QuantizedType
getQuantizedElementType(Type primitiveOrContainerType)93 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
94   if (primitiveOrContainerType.isa<ShapedType>()) {
95     Type elementType =
96         primitiveOrContainerType.cast<ShapedType>().getElementType();
97     return elementType.dyn_cast<QuantizedType>();
98   }
99   return primitiveOrContainerType.dyn_cast<QuantizedType>();
100 }
101 
castFromStorageType(Type candidateType)102 Type QuantizedType::castFromStorageType(Type candidateType) {
103   if (candidateType == getStorageType()) {
104     // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
105     return *this;
106   } else if (candidateType.isa<RankedTensorType>()) {
107     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
108     return RankedTensorType::get(
109         candidateType.cast<RankedTensorType>().getShape(), getStorageType());
110   } else if (candidateType.isa<UnrankedTensorType>()) {
111     // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
112     return UnrankedTensorType::get(getStorageType());
113   } else if (candidateType.isa<VectorType>()) {
114     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
115     return VectorType::get(candidateType.cast<VectorType>().getShape(),
116                            getStorageType());
117   }
118 
119   return nullptr;
120 }
121 
castToStorageType(Type quantizedType)122 Type QuantizedType::castToStorageType(Type quantizedType) {
123   if (quantizedType.isa<QuantizedType>()) {
124     // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
125     return quantizedType.cast<QuantizedType>().getStorageType();
126   } else if (quantizedType.isa<ShapedType>()) {
127     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
128     ShapedType sType = quantizedType.cast<ShapedType>();
129     if (!sType.getElementType().isa<QuantizedType>()) {
130       return nullptr;
131     }
132     Type storageType =
133         sType.getElementType().cast<QuantizedType>().getStorageType();
134     if (quantizedType.isa<RankedTensorType>()) {
135       return RankedTensorType::get(sType.getShape(), storageType);
136     } else if (quantizedType.isa<UnrankedTensorType>()) {
137       return UnrankedTensorType::get(storageType);
138     } else if (quantizedType.isa<VectorType>()) {
139       return VectorType::get(sType.getShape(), storageType);
140     }
141   }
142 
143   return nullptr;
144 }
145 
castFromExpressedType(Type candidateType)146 Type QuantizedType::castFromExpressedType(Type candidateType) {
147   if (candidateType == getExpressedType()) {
148     // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
149     return *this;
150   } else if (candidateType.isa<ShapedType>()) {
151     ShapedType candidateShapedType = candidateType.cast<ShapedType>();
152     if (candidateShapedType.getElementType() != getExpressedType()) {
153       return nullptr;
154     }
155 
156     if (candidateType.isa<RankedTensorType>()) {
157       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
158       return RankedTensorType::get(candidateShapedType.getShape(), *this);
159     } else if (candidateType.isa<UnrankedTensorType>()) {
160       // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
161       return UnrankedTensorType::get(*this);
162     } else if (candidateType.isa<VectorType>()) {
163       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
164       return VectorType::get(candidateShapedType.getShape(), *this);
165     }
166   }
167 
168   return nullptr;
169 }
170 
castToExpressedType(Type quantizedType)171 Type QuantizedType::castToExpressedType(Type quantizedType) {
172   if (quantizedType.isa<QuantizedType>()) {
173     // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
174     return quantizedType.cast<QuantizedType>().getExpressedType();
175   } else if (quantizedType.isa<ShapedType>()) {
176     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
177     ShapedType sType = quantizedType.cast<ShapedType>();
178     if (!sType.getElementType().isa<QuantizedType>()) {
179       return nullptr;
180     }
181     Type expressedType =
182         sType.getElementType().cast<QuantizedType>().getExpressedType();
183     if (quantizedType.isa<RankedTensorType>()) {
184       return RankedTensorType::get(sType.getShape(), expressedType);
185     } else if (quantizedType.isa<UnrankedTensorType>()) {
186       return UnrankedTensorType::get(expressedType);
187     } else if (quantizedType.isa<VectorType>()) {
188       return VectorType::get(sType.getShape(), expressedType);
189     }
190   }
191 
192   return nullptr;
193 }
194 
castExpressedToStorageType(Type candidateType)195 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
196   Type expressedQuantizedType = castFromExpressedType(candidateType);
197   if (!expressedQuantizedType) {
198     return nullptr;
199   }
200   return QuantizedType::castToStorageType(expressedQuantizedType);
201 }
202 
get(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)203 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
204                                        Type expressedType,
205                                        int64_t storageTypeMin,
206                                        int64_t storageTypeMax) {
207   return Base::get(storageType.getContext(), flags, storageType, expressedType,
208                    storageTypeMin, storageTypeMax);
209 }
210 
getChecked(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax,Location location)211 AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
212                                               Type expressedType,
213                                               int64_t storageTypeMin,
214                                               int64_t storageTypeMax,
215                                               Location location) {
216   return Base::getChecked(location, flags, storageType, expressedType,
217                           storageTypeMin, storageTypeMax);
218 }
219 
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)220 LogicalResult AnyQuantizedType::verifyConstructionInvariants(
221     Location loc, unsigned flags, Type storageType, Type expressedType,
222     int64_t storageTypeMin, int64_t storageTypeMax) {
223   if (failed(QuantizedType::verifyConstructionInvariants(
224           loc, flags, storageType, expressedType, storageTypeMin,
225           storageTypeMax))) {
226     return failure();
227   }
228 
229   // Verify that the expressed type is floating point.
230   // If this restriction is ever eliminated, the parser/printer must be
231   // extended.
232   if (expressedType && !expressedType.isa<FloatType>())
233     return emitError(loc, "expressed type must be floating point");
234 
235   return success();
236 }
237 
get(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)238 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
239                                                Type expressedType, double scale,
240                                                int64_t zeroPoint,
241                                                int64_t storageTypeMin,
242                                                int64_t storageTypeMax) {
243   return Base::get(storageType.getContext(), flags, storageType, expressedType,
244                    scale, zeroPoint, storageTypeMin, storageTypeMax);
245 }
246 
247 UniformQuantizedType
getChecked(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax,Location location)248 UniformQuantizedType::getChecked(unsigned flags, Type storageType,
249                                  Type expressedType, double scale,
250                                  int64_t zeroPoint, int64_t storageTypeMin,
251                                  int64_t storageTypeMax, Location location) {
252   return Base::getChecked(location, flags, storageType, expressedType, scale,
253                           zeroPoint, storageTypeMin, storageTypeMax);
254 }
255 
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)256 LogicalResult UniformQuantizedType::verifyConstructionInvariants(
257     Location loc, unsigned flags, Type storageType, Type expressedType,
258     double scale, int64_t zeroPoint, int64_t storageTypeMin,
259     int64_t storageTypeMax) {
260   if (failed(QuantizedType::verifyConstructionInvariants(
261           loc, flags, storageType, expressedType, storageTypeMin,
262           storageTypeMax))) {
263     return failure();
264   }
265 
266   // Uniform quantization requires fully expressed parameters, including
267   // expressed type.
268   if (!expressedType)
269     return emitError(loc, "uniform quantization requires expressed type");
270 
271   // Verify that the expressed type is floating point.
272   // If this restriction is ever eliminated, the parser/printer must be
273   // extended.
274   if (!expressedType.isa<FloatType>())
275     return emitError(loc, "expressed type must be floating point");
276 
277   // Verify scale.
278   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
279     return emitError(loc, "illegal scale: ") << scale;
280 
281   return success();
282 }
283 
getScale() const284 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
285 
getZeroPoint() const286 int64_t UniformQuantizedType::getZeroPoint() const {
287   return getImpl()->zeroPoint;
288 }
289 
get(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)290 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
291     unsigned flags, Type storageType, Type expressedType,
292     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
293     int32_t quantizedDimension, int64_t storageTypeMin,
294     int64_t storageTypeMax) {
295   return Base::get(storageType.getContext(), flags, storageType, expressedType,
296                    scales, zeroPoints, quantizedDimension, storageTypeMin,
297                    storageTypeMax);
298 }
299 
getChecked(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax,Location location)300 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
301     unsigned flags, Type storageType, Type expressedType,
302     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
303     int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
304     Location location) {
305   return Base::getChecked(location, flags, storageType, expressedType, scales,
306                           zeroPoints, quantizedDimension, storageTypeMin,
307                           storageTypeMax);
308 }
309 
verifyConstructionInvariants(Location loc,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)310 LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
311     Location loc, unsigned flags, Type storageType, Type expressedType,
312     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
313     int32_t quantizedDimension, int64_t storageTypeMin,
314     int64_t storageTypeMax) {
315   if (failed(QuantizedType::verifyConstructionInvariants(
316           loc, flags, storageType, expressedType, storageTypeMin,
317           storageTypeMax))) {
318     return failure();
319   }
320 
321   // Uniform quantization requires fully expressed parameters, including
322   // expressed type.
323   if (!expressedType)
324     return emitError(loc, "uniform quantization requires expressed type");
325 
326   // Verify that the expressed type is floating point.
327   // If this restriction is ever eliminated, the parser/printer must be
328   // extended.
329   if (!expressedType.isa<FloatType>())
330     return emitError(loc, "expressed type must be floating point");
331 
332   // Ensure that the number of scales and zeroPoints match.
333   if (scales.size() != zeroPoints.size())
334     return emitError(loc, "illegal number of scales and zeroPoints: ")
335            << scales.size() << ", " << zeroPoints.size();
336 
337   // Verify scale.
338   for (double scale : scales) {
339     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
340       return emitError(loc, "illegal scale: ") << scale;
341   }
342 
343   return success();
344 }
345 
getScales() const346 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
347   return getImpl()->getScales();
348 }
349 
getZeroPoints() const350 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
351   return getImpl()->getZeroPoints();
352 }
353 
getQuantizedDimension() const354 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
355   return getImpl()->quantizedDimension;
356 }
357 
get(Type expressedType,double min,double max)358 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
359                                                      double min, double max) {
360   return Base::get(expressedType.getContext(), expressedType, min, max);
361 }
362 
getChecked(Type expressedType,double min,double max,Location location)363 CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType,
364                                                             double min,
365                                                             double max,
366                                                             Location location) {
367   return Base::getChecked(location, expressedType, min, max);
368 }
369 
verifyConstructionInvariants(Location loc,Type expressedType,double min,double max)370 LogicalResult CalibratedQuantizedType::verifyConstructionInvariants(
371     Location loc, Type expressedType, double min, double max) {
372   // Verify that the expressed type is floating point.
373   // If this restriction is ever eliminated, the parser/printer must be
374   // extended.
375   if (!expressedType.isa<FloatType>())
376     return emitError(loc, "expressed type must be floating point");
377   if (max <= min)
378     return emitError(loc, "illegal min and max: (") << min << ":" << max << ")";
379 
380   return success();
381 }
382 
getMin() const383 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
384 
getMax() const385 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
386