• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef TYPE_DETAIL_H_
10 #define TYPE_DETAIL_H_
11 
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/bit.h"

#include <cassert>
18 
19 namespace mlir {
20 namespace quant {
21 namespace detail {
22 
23 struct QuantizedTypeStorage : public mlir::TypeStorage {
QuantizedTypeStorageQuantizedTypeStorage24   QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType,
25                        int64_t storageTypeMin, int64_t storageTypeMax)
26       : flags(flags), storageType(storageType), expressedType(expressedType),
27         storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
28 
29   /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
30   unsigned flags;
31 
32   // Integral type for the storage point representation.
33   Type storageType;
34 
35   // Floating point type that the quantized type approximates.
36   Type expressedType;
37 
38   // The minimum value storageType can take.
39   int64_t storageTypeMin;
40 
41   // The maximum value storageType can take.
42   int64_t storageTypeMax;
43 };
44 
45 struct AnyQuantizedTypeStorage : public QuantizedTypeStorage {
46   struct KeyTy {
KeyTyAnyQuantizedTypeStorage::KeyTy47     KeyTy(unsigned flags, Type storageType, Type expressedType,
48           int64_t storageTypeMin, int64_t storageTypeMax)
49         : flags(flags), storageType(storageType), expressedType(expressedType),
50           storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
51     unsigned flags;
52     Type storageType;
53     Type expressedType;
54     int64_t storageTypeMin;
55     int64_t storageTypeMax;
56 
57     // Check for equality of two structures that share KeyTy data members
58     // (by name).
59     template <typename T, typename U>
genericIsEqualAnyQuantizedTypeStorage::KeyTy60     static bool genericIsEqual(const T &lhs, const U &rhs) {
61       return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
62              lhs.expressedType == rhs.expressedType &&
63              lhs.storageTypeMin == rhs.storageTypeMin &&
64              lhs.storageTypeMax == rhs.storageTypeMax;
65     }
66 
67     bool operator==(const KeyTy &other) const {
68       return genericIsEqual(*this, other);
69     }
70 
getHashValueAnyQuantizedTypeStorage::KeyTy71     unsigned getHashValue() const {
72       return llvm::hash_combine(flags, storageType, expressedType,
73                                 storageTypeMin, storageTypeMax);
74     }
75   };
76 
AnyQuantizedTypeStorageAnyQuantizedTypeStorage77   AnyQuantizedTypeStorage(const KeyTy &key)
78       : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
79                              key.storageTypeMin, key.storageTypeMax) {}
80 
81   bool operator==(const KeyTy &key) const {
82     return KeyTy::genericIsEqual(*this, key);
83   }
84 
85   /// Construction.
constructAnyQuantizedTypeStorage86   static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
87                                             const KeyTy &key) {
88     return new (allocator.allocate<AnyQuantizedTypeStorage>())
89         AnyQuantizedTypeStorage(key);
90   }
91 
hashKeyAnyQuantizedTypeStorage92   static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
93 };
94 
95 struct UniformQuantizedTypeStorage : public QuantizedTypeStorage {
96   struct KeyTy {
KeyTyUniformQuantizedTypeStorage::KeyTy97     KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
98           int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
99         : flags(flags), storageType(storageType), expressedType(expressedType),
100           scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
101           storageTypeMax(storageTypeMax) {}
102     /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
103     unsigned flags;
104 
105     // Integral type for the storage point representation.
106     Type storageType;
107 
108     // Floating point type that the quantized type approximates.
109     Type expressedType;
110 
111     double scale;
112     int64_t zeroPoint;
113     int64_t storageTypeMin;
114     int64_t storageTypeMax;
115 
116     // Check for equality of two structures that share KeyTy data members
117     // (by name).
118     template <typename T, typename U>
genericIsEqualUniformQuantizedTypeStorage::KeyTy119     static bool genericIsEqual(const T &lhs, const U &rhs) {
120       return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
121              lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
122              lhs.zeroPoint == rhs.zeroPoint &&
123              lhs.storageTypeMin == rhs.storageTypeMin &&
124              lhs.storageTypeMax == rhs.storageTypeMax;
125     }
126 
127     bool operator==(const KeyTy &other) const {
128       return genericIsEqual(*this, other);
129     }
130 
getHashValueUniformQuantizedTypeStorage::KeyTy131     unsigned getHashValue() const {
132       int64_t scaleBits = llvm::bit_cast<int64_t>(scale);
133       return llvm::hash_combine(flags, storageType, expressedType, scaleBits,
134                                 zeroPoint, storageTypeMin, storageTypeMax);
135     }
136   };
137 
UniformQuantizedTypeStorageUniformQuantizedTypeStorage138   UniformQuantizedTypeStorage(const KeyTy &key)
139       : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
140                              key.storageTypeMin, key.storageTypeMax),
141         scale(key.scale), zeroPoint(key.zeroPoint) {}
142 
143   bool operator==(const KeyTy &key) const {
144     return KeyTy::genericIsEqual(*this, key);
145   }
146 
147   /// Construction.
constructUniformQuantizedTypeStorage148   static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
149                                                 const KeyTy &key) {
150     return new (allocator.allocate<UniformQuantizedTypeStorage>())
151         UniformQuantizedTypeStorage(key);
152   }
153 
hashKeyUniformQuantizedTypeStorage154   static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
155 
156   double scale;
157   int64_t zeroPoint;
158 };
159 
160 struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
161   struct KeyTy {
KeyTyUniformQuantizedPerAxisTypeStorage::KeyTy162     KeyTy(unsigned flags, Type storageType, Type expressedType,
163           ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
164           int32_t quantizedDimension, int64_t storageTypeMin,
165           int64_t storageTypeMax)
166         : flags(flags), storageType(storageType), expressedType(expressedType),
167           scales(scales), zeroPoints(zeroPoints),
168           quantizedDimension(quantizedDimension),
169           storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
170     /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
171     unsigned flags;
172 
173     // Integral type for the storage point representation.
174     Type storageType;
175 
176     // Floating point type that the quantized type approximates.
177     Type expressedType;
178 
179     ArrayRef<double> scales;
180     ArrayRef<int64_t> zeroPoints;
181     int32_t quantizedDimension;
182     int64_t storageTypeMin;
183     int64_t storageTypeMax;
184 
getScalesUniformQuantizedPerAxisTypeStorage::KeyTy185     ArrayRef<double> getScales() const { return scales; }
186 
getZeroPointsUniformQuantizedPerAxisTypeStorage::KeyTy187     ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
188 
189     // Check for equality of two structures that share KeyTy data members
190     // (by name).
191     template <typename T, typename U>
genericIsEqualUniformQuantizedPerAxisTypeStorage::KeyTy192     static bool genericIsEqual(const T &lhs, const U &rhs) {
193       return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
194              lhs.expressedType == rhs.expressedType &&
195              lhs.getScales() == rhs.getScales() &&
196              lhs.getZeroPoints() == rhs.getZeroPoints() &&
197              lhs.quantizedDimension == rhs.quantizedDimension &&
198              lhs.storageTypeMin == rhs.storageTypeMin &&
199              lhs.storageTypeMax == rhs.storageTypeMax;
200     }
201 
202     bool operator==(const KeyTy &other) const {
203       return genericIsEqual(*this, other);
204     }
205 
getHashValueUniformQuantizedPerAxisTypeStorage::KeyTy206     unsigned getHashValue() const {
207       int64_t *scalesCast = llvm::bit_cast<int64_t *>(scales.data());
208       ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
209       return llvm::hash_combine(
210           flags, storageType, expressedType,
211           llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()),
212           llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()),
213           storageTypeMin, storageTypeMax);
214     }
215   };
216 
217   // We pass scales and zeroPoints in directly rather than relying on KeyTy
218   // because we have to create new reallocated versions in `construct` below.
UniformQuantizedPerAxisTypeStorageUniformQuantizedPerAxisTypeStorage219   UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales,
220                                      ArrayRef<int64_t> zeroPoints)
221       : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
222                              key.storageTypeMin, key.storageTypeMax),
223         scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
224         quantParamsSize(scales.size()),
225         quantizedDimension(key.quantizedDimension) {}
226 
227   bool operator==(const KeyTy &key) const {
228     return KeyTy::genericIsEqual(*this, key);
229   }
230 
231   /// Construction.
232   static UniformQuantizedPerAxisTypeStorage *
constructUniformQuantizedPerAxisTypeStorage233   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
234     ArrayRef<double> scales = allocator.copyInto(key.scales);
235     ArrayRef<int64_t> zeroPoints = allocator.copyInto(key.zeroPoints);
236     return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
237         UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
238   }
239 
hashKeyUniformQuantizedPerAxisTypeStorage240   static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
241 
getScalesUniformQuantizedPerAxisTypeStorage242   ArrayRef<double> getScales() const {
243     return ArrayRef<double>(scaleElements, quantParamsSize);
244   }
245 
getZeroPointsUniformQuantizedPerAxisTypeStorage246   ArrayRef<int64_t> getZeroPoints() const {
247     return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
248   }
249 
250   const double *scaleElements;
251   const int64_t *zeroPointElements;
252   unsigned quantParamsSize;
253   int32_t quantizedDimension;
254 };
255 
256 struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
257   struct KeyTy {
KeyTyCalibratedQuantizedTypeStorage::KeyTy258     KeyTy(Type expressedType, double min, double max)
259         : expressedType(expressedType), min(min), max(max) {}
260     // Floating point type that the quantized type approximates.
261     Type expressedType;
262 
263     double min;
264     double max;
265 
266     // Check for equality of two structures that share KeyTy data members
267     // (by name).
268     template <typename T, typename U>
genericIsEqualCalibratedQuantizedTypeStorage::KeyTy269     static bool genericIsEqual(const T &lhs, const U &rhs) {
270       return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min &&
271              lhs.max == rhs.max;
272     }
273 
274     bool operator==(const KeyTy &other) const {
275       return genericIsEqual(*this, other);
276     }
277 
getHashValueCalibratedQuantizedTypeStorage::KeyTy278     unsigned getHashValue() const {
279       int64_t minBits = llvm::bit_cast<double>(min);
280       int64_t maxBits = llvm::bit_cast<double>(max);
281       return llvm::hash_combine(expressedType, minBits, maxBits);
282     }
283   };
284 
CalibratedQuantizedTypeStorageCalibratedQuantizedTypeStorage285   CalibratedQuantizedTypeStorage(const KeyTy &key)
286       : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0),
287         min(key.min), max(key.max) {}
288 
289   bool operator==(const KeyTy &key) const {
290     return KeyTy::genericIsEqual(*this, key);
291   }
292 
293   /// Construction.
294   static CalibratedQuantizedTypeStorage *
constructCalibratedQuantizedTypeStorage295   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
296     return new (allocator.allocate<CalibratedQuantizedTypeStorage>())
297         CalibratedQuantizedTypeStorage(key);
298   }
299 
hashKeyCalibratedQuantizedTypeStorage300   static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
301 
302   double min;
303   double max;
304 };
305 
306 } // namespace detail
307 } // namespace quant
308 } // namespace mlir
309 
310 #endif // TYPE_DETAIL_H_
311