1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h"
17
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20 #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
21
22 using namespace mlir;
23 using namespace mlir::quantfork;
24
25 /// Converts a possible primitive, real expressed value attribute to a
26 /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
27 /// quantizedElementType is the QuantizedType that describes the expressed
28 /// origValue.
29 /// Returns a converter Attribute or nullptr if conversion is not possible.
convertPrimitiveValueAttr(Attribute origRealValue,quant::QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)30 static Attribute convertPrimitiveValueAttr(
31 Attribute origRealValue, quant::QuantizedType quantizedElementType,
32 const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
33 if (origRealValue.isa<FloatAttr>()) {
34 FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
35 outConvertedType = quantizedElementType.getStorageType();
36 return IntegerAttr::get(quantizedElementType.getStorageType(),
37 converter.quantizeFloatToInt(floatAttr.getValue()));
38 }
39
40 return nullptr;
41 }
42
43 /// Converts a real expressed DenseFPElementsAttr to a corresponding
44 /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
45 /// storage values assuming the given quantizedElementType and converter.
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,quant::QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)46 static DenseElementsAttr convertDenseFPElementsAttr(
47 DenseFPElementsAttr realFPElementsAttr,
48 quant::QuantizedType quantizedElementType,
49 const UniformQuantizedValueConverter &converter) {
50 // Convert to corresponding quantized value attributes.
51 SmallVector<APInt, 8> quantValues;
52 if (realFPElementsAttr.isSplat()) {
53 quantValues.push_back(
54 converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
55 } else {
56 quantValues.reserve(realFPElementsAttr.getNumElements());
57 for (APFloat realVal : realFPElementsAttr) {
58 quantValues.push_back(converter.quantizeFloatToInt(realVal));
59 }
60 }
61
62 // Cast from an expressed-type-based type to storage-type-based type,
63 // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
64 ShapedType newDenseType =
65 quantizedElementType
66 .castExpressedToStorageType(realFPElementsAttr.getType())
67 .dyn_cast_or_null<ShapedType>();
68 if (!newDenseType) {
69 return nullptr;
70 }
71 return DenseIntElementsAttr::get(newDenseType, quantValues);
72 }
73
74 /// Converts a real expressed SplatElementsAttr to a corresponding
75 /// SplatElementsAttr containing quantized storage values assuming the given
76 /// quantizedElementType and converter.
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,quant::QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)77 static SparseElementsAttr convertSparseElementsAttr(
78 SparseElementsAttr realSparseAttr,
79 quant::QuantizedType quantizedElementType,
80 const UniformQuantizedValueConverter &converter) {
81 DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
82 if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
83 return nullptr;
84 }
85 DenseElementsAttr quantDenseAttr =
86 convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
87 quantizedElementType, converter);
88 if (!quantDenseAttr) {
89 return nullptr;
90 }
91
92 // Cast from an expressed-type-based type to storage-type-based type,
93 // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
94 ShapedType newSparseType =
95 quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
96 .dyn_cast_or_null<ShapedType>();
97 if (!newSparseType) {
98 return nullptr;
99 }
100 return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
101 quantDenseAttr);
102 }
103
104 /// Converts a real expressed Attribute to a corresponding Attribute containing
105 /// quantized storage values assuming the given uniform quantizedElementType and
106 /// converter.
quantizeAttrUniform(Attribute realValue,quant::UniformQuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)107 Attribute mlir::quantfork::quantizeAttrUniform(
108 Attribute realValue, quant::UniformQuantizedType quantizedElementType,
109 const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
110 // Fork to handle different variants of constants supported.
111 if (realValue.isa<DenseFPElementsAttr>()) {
112 // Dense tensor or vector constant.
113 auto converted = convertDenseFPElementsAttr(
114 realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
115 outConvertedType = converted.getType();
116 return converted;
117 }
118 if (realValue.isa<SparseElementsAttr>()) {
119 // Sparse tensor or vector constant.
120 auto converted = convertSparseElementsAttr(
121 realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
122 outConvertedType = converted.getType();
123 return converted;
124 }
125 // Nothing else matched: try to convert a primitive.
126 return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
127 outConvertedType);
128 }
129
130 /// Convert an attribute from a type based on
131 /// quantizedElementType.getExpressedType() to one based on
132 /// quantizedElementType.getStorageType().
133 /// Returns nullptr if the conversion is not supported.
134 /// On success, stores the converted type in outConvertedType.
quantizeAttr(Attribute realValue,quant::QuantizedType quantizedElementType,Type & outConvertedType)135 Attribute mlir::quantfork::quantizeAttr(
136 Attribute realValue, quant::QuantizedType quantizedElementType,
137 Type &outConvertedType) {
138 if (auto uniformQuantized =
139 quantizedElementType.dyn_cast<quant::UniformQuantizedType>()) {
140 UniformQuantizedValueConverter converter(uniformQuantized);
141 return quantizeAttrUniform(realValue, uniformQuantized, converter,
142 outConvertedType);
143 }
144 if (auto uniformQuantizedPerAxis =
145 quantizedElementType.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
146 UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
147 auto converted = converter.convert(realValue);
148 // TODO: why we need this outConvertedType? remove it?
149 if (converted) {
150 outConvertedType = converted.getType();
151 }
152 return converted;
153 }
154 return nullptr;
155 }
156