1 //===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantizeUtils.h"
10 #include "mlir/Dialect/Quant/UniformSupport.h"
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13
14 using namespace mlir;
15 using namespace mlir::quant;
16
17 /// Converts a possible primitive, real expressed value attribute to a
18 /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
19 /// quantizedElementType is the QuantizedType that describes the expressed
20 /// origValue.
21 /// Returns a converter Attribute or nullptr if conversion is not possible.
convertPrimitiveValueAttr(Attribute origRealValue,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)22 static Attribute convertPrimitiveValueAttr(
23 Attribute origRealValue, QuantizedType quantizedElementType,
24 const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
25 if (origRealValue.isa<FloatAttr>()) {
26 FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
27 outConvertedType = quantizedElementType.getStorageType();
28 return IntegerAttr::get(quantizedElementType.getStorageType(),
29 converter.quantizeFloatToInt(floatAttr.getValue()));
30 }
31
32 return nullptr;
33 }
34
35 /// Converts a real expressed DenseFPElementsAttr to a corresponding
36 /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
37 /// storage values assuming the given quantizedElementType and converter.
38 static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)39 convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
40 QuantizedType quantizedElementType,
41 const UniformQuantizedValueConverter &converter) {
42 // Convert to corresponding quantized value attributes.
43 SmallVector<APInt, 8> quantValues;
44 if (realFPElementsAttr.isSplat()) {
45 quantValues.push_back(
46 converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
47 } else {
48 quantValues.reserve(realFPElementsAttr.getNumElements());
49 for (APFloat realVal : realFPElementsAttr) {
50 quantValues.push_back(converter.quantizeFloatToInt(realVal));
51 }
52 }
53
54 // Cast from an expressed-type-based type to storage-type-based type,
55 // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
56 ShapedType newDenseType =
57 quantizedElementType
58 .castExpressedToStorageType(realFPElementsAttr.getType())
59 .dyn_cast_or_null<ShapedType>();
60 if (!newDenseType) {
61 return nullptr;
62 }
63 return DenseIntElementsAttr::get(newDenseType, quantValues);
64 }
65
66 /// Converts a real expressed SplatElementsAttr to a corresponding
67 /// SplatElementsAttr containing quantized storage values assuming the given
68 /// quantizedElementType and converter.
69 static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)70 convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
71 QuantizedType quantizedElementType,
72 const UniformQuantizedValueConverter &converter) {
73 DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
74 if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
75 return nullptr;
76 }
77 DenseElementsAttr quantDenseAttr =
78 convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
79 quantizedElementType, converter);
80 if (!quantDenseAttr) {
81 return nullptr;
82 }
83
84 // Cast from an expressed-type-based type to storage-type-based type,
85 // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
86 ShapedType newSparseType =
87 quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
88 .dyn_cast_or_null<ShapedType>();
89 if (!newSparseType) {
90 return nullptr;
91 }
92 return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
93 quantDenseAttr);
94 }
95
96 /// Converts a real expressed Attribute to a corresponding Attribute containing
97 /// quantized storage values assuming the given uniform quantizedElementType and
98 /// converter.
quantizeAttrUniform(Attribute realValue,UniformQuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)99 Attribute mlir::quant::quantizeAttrUniform(
100 Attribute realValue, UniformQuantizedType quantizedElementType,
101 const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
102 // Fork to handle different variants of constants supported.
103 if (realValue.isa<DenseFPElementsAttr>()) {
104 // Dense tensor or vector constant.
105 auto converted = convertDenseFPElementsAttr(
106 realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
107 outConvertedType = converted.getType();
108 return converted;
109 } else if (realValue.isa<SparseElementsAttr>()) {
110 // Sparse tensor or vector constant.
111 auto converted = convertSparseElementsAttr(
112 realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
113 outConvertedType = converted.getType();
114 return converted;
115 } else {
116 // Nothing else matched: try to convert a primitive.
117 return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
118 outConvertedType);
119 }
120 }
121
122 /// Convert an attribute from a type based on
123 /// quantizedElementType.getExpressedType() to one based on
124 /// quantizedElementType.getStorageType().
125 /// Returns nullptr if the conversion is not supported.
126 /// On success, stores the converted type in outConvertedType.
quantizeAttr(Attribute realValue,QuantizedType quantizedElementType,Type & outConvertedType)127 Attribute mlir::quant::quantizeAttr(Attribute realValue,
128 QuantizedType quantizedElementType,
129 Type &outConvertedType) {
130 if (auto uniformQuantized =
131 quantizedElementType.dyn_cast<UniformQuantizedType>()) {
132 UniformQuantizedValueConverter converter(uniformQuantized);
133 return quantizeAttrUniform(realValue, uniformQuantized, converter,
134 outConvertedType);
135
136 } else if (auto uniformQuantizedPerAxis =
137 quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
138 UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
139 auto converted = converter.convert(realValue);
140 // TODO: why we need this outConvertedType? remove it?
141 if (converted) {
142 outConvertedType = converted.getType();
143 }
144 return converted;
145 } else {
146 return nullptr;
147 }
148 }
149