• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h"
17 
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
21 
22 using namespace mlir;
23 using namespace mlir::quantfork;
24 
25 /// Converts a possible primitive, real expressed value attribute to a
26 /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
27 /// quantizedElementType is the QuantizedType that describes the expressed
28 /// origValue.
29 /// Returns a converter Attribute or nullptr if conversion is not possible.
convertPrimitiveValueAttr(Attribute origRealValue,quant::QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)30 static Attribute convertPrimitiveValueAttr(
31     Attribute origRealValue, quant::QuantizedType quantizedElementType,
32     const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
33   if (origRealValue.isa<FloatAttr>()) {
34     FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
35     outConvertedType = quantizedElementType.getStorageType();
36     return IntegerAttr::get(quantizedElementType.getStorageType(),
37                             converter.quantizeFloatToInt(floatAttr.getValue()));
38   }
39 
40   return nullptr;
41 }
42 
43 /// Converts a real expressed DenseFPElementsAttr to a corresponding
44 /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
45 /// storage values assuming the given quantizedElementType and converter.
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,quant::QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)46 static DenseElementsAttr convertDenseFPElementsAttr(
47     DenseFPElementsAttr realFPElementsAttr,
48     quant::QuantizedType quantizedElementType,
49     const UniformQuantizedValueConverter &converter) {
50   // Convert to corresponding quantized value attributes.
51   SmallVector<APInt, 8> quantValues;
52   if (realFPElementsAttr.isSplat()) {
53     quantValues.push_back(
54         converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
55   } else {
56     quantValues.reserve(realFPElementsAttr.getNumElements());
57     for (APFloat realVal : realFPElementsAttr) {
58       quantValues.push_back(converter.quantizeFloatToInt(realVal));
59     }
60   }
61 
62   // Cast from an expressed-type-based type to storage-type-based type,
63   // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
64   ShapedType newDenseType =
65       quantizedElementType
66           .castExpressedToStorageType(realFPElementsAttr.getType())
67           .dyn_cast_or_null<ShapedType>();
68   if (!newDenseType) {
69     return nullptr;
70   }
71   return DenseIntElementsAttr::get(newDenseType, quantValues);
72 }
73 
74 /// Converts a real expressed SplatElementsAttr to a corresponding
75 /// SplatElementsAttr containing quantized storage values assuming the given
76 /// quantizedElementType and converter.
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,quant::QuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter)77 static SparseElementsAttr convertSparseElementsAttr(
78     SparseElementsAttr realSparseAttr,
79     quant::QuantizedType quantizedElementType,
80     const UniformQuantizedValueConverter &converter) {
81   DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
82   if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
83     return nullptr;
84   }
85   DenseElementsAttr quantDenseAttr =
86       convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
87                                  quantizedElementType, converter);
88   if (!quantDenseAttr) {
89     return nullptr;
90   }
91 
92   // Cast from an expressed-type-based type to storage-type-based type,
93   // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
94   ShapedType newSparseType =
95       quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
96           .dyn_cast_or_null<ShapedType>();
97   if (!newSparseType) {
98     return nullptr;
99   }
100   return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
101                                  quantDenseAttr);
102 }
103 
104 /// Converts a real expressed Attribute to a corresponding Attribute containing
105 /// quantized storage values assuming the given uniform quantizedElementType and
106 /// converter.
quantizeAttrUniform(Attribute realValue,quant::UniformQuantizedType quantizedElementType,const UniformQuantizedValueConverter & converter,Type & outConvertedType)107 Attribute mlir::quantfork::quantizeAttrUniform(
108     Attribute realValue, quant::UniformQuantizedType quantizedElementType,
109     const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
110   // Fork to handle different variants of constants supported.
111   if (realValue.isa<DenseFPElementsAttr>()) {
112     // Dense tensor or vector constant.
113     auto converted = convertDenseFPElementsAttr(
114         realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
115     outConvertedType = converted.getType();
116     return converted;
117   }
118   if (realValue.isa<SparseElementsAttr>()) {
119     // Sparse tensor or vector constant.
120     auto converted = convertSparseElementsAttr(
121         realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
122     outConvertedType = converted.getType();
123     return converted;
124   }
125   // Nothing else matched: try to convert a primitive.
126   return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
127                                    outConvertedType);
128 }
129 
130 /// Convert an attribute from a type based on
131 /// quantizedElementType.getExpressedType() to one based on
132 /// quantizedElementType.getStorageType().
133 /// Returns nullptr if the conversion is not supported.
134 /// On success, stores the converted type in outConvertedType.
quantizeAttr(Attribute realValue,quant::QuantizedType quantizedElementType,Type & outConvertedType)135 Attribute mlir::quantfork::quantizeAttr(
136     Attribute realValue, quant::QuantizedType quantizedElementType,
137     Type &outConvertedType) {
138   if (auto uniformQuantized =
139           quantizedElementType.dyn_cast<quant::UniformQuantizedType>()) {
140     UniformQuantizedValueConverter converter(uniformQuantized);
141     return quantizeAttrUniform(realValue, uniformQuantized, converter,
142                                outConvertedType);
143   }
144   if (auto uniformQuantizedPerAxis =
145           quantizedElementType.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
146     UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
147     auto converted = converter.convert(realValue);
148     // TODO: why we need this outConvertedType? remove it?
149     if (converted) {
150       outConvertedType = converted.getType();
151     }
152     return converted;
153   }
154   return nullptr;
155 }
156