//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/raw_ostream.h"
21
22 using namespace mlir;
23 using namespace quant;
24
parseStorageType(DialectAsmParser & parser,bool & isSigned)25 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26 auto typeLoc = parser.getCurrentLocation();
27 IntegerType type;
28
29 // Parse storage type (alpha_ident, integer_literal).
30 StringRef identifier;
31 unsigned storageTypeWidth = 0;
32 if (failed(parser.parseOptionalKeyword(&identifier))) {
33 // If we didn't parse a keyword, this must be a signed type.
34 if (parser.parseType(type))
35 return nullptr;
36 isSigned = true;
37 storageTypeWidth = type.getWidth();
38
39 // Otherwise, this must be an unsigned integer (`u` integer-literal).
40 } else {
41 if (!identifier.consume_front("u")) {
42 parser.emitError(typeLoc, "illegal storage type prefix");
43 return nullptr;
44 }
45 if (identifier.getAsInteger(10, storageTypeWidth)) {
46 parser.emitError(typeLoc, "expected storage type width");
47 return nullptr;
48 }
49 isSigned = false;
50 type = parser.getBuilder().getIntegerType(storageTypeWidth);
51 }
52
53 if (storageTypeWidth == 0 ||
54 storageTypeWidth > QuantizedType::MaxStorageBits) {
55 parser.emitError(typeLoc, "illegal storage type size: ")
56 << storageTypeWidth;
57 return nullptr;
58 }
59
60 return type;
61 }
62
parseStorageRange(DialectAsmParser & parser,IntegerType storageType,bool isSigned,int64_t & storageTypeMin,int64_t & storageTypeMax)63 static ParseResult parseStorageRange(DialectAsmParser &parser,
64 IntegerType storageType, bool isSigned,
65 int64_t &storageTypeMin,
66 int64_t &storageTypeMax) {
67 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68 isSigned, storageType.getWidth());
69 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70 isSigned, storageType.getWidth());
71 if (failed(parser.parseOptionalLess())) {
72 storageTypeMin = defaultIntegerMin;
73 storageTypeMax = defaultIntegerMax;
74 return success();
75 }
76
77 // Explicit storage min and storage max.
78 llvm::SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
80 parser.getCurrentLocation(&maxLoc) ||
81 parser.parseInteger(storageTypeMax) || parser.parseGreater())
82 return failure();
83 if (storageTypeMin < defaultIntegerMin) {
84 return parser.emitError(minLoc, "illegal storage type minimum: ")
85 << storageTypeMin;
86 }
87 if (storageTypeMax > defaultIntegerMax) {
88 return parser.emitError(maxLoc, "illegal storage type maximum: ")
89 << storageTypeMax;
90 }
91 return success();
92 }
93
parseExpressedTypeAndRange(DialectAsmParser & parser,double & min,double & max)94 static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
95 double &min, double &max) {
96 auto typeLoc = parser.getCurrentLocation();
97 FloatType type;
98
99 if (failed(parser.parseType(type))) {
100 parser.emitError(typeLoc, "expecting float expressed type");
101 return nullptr;
102 }
103
104 // Calibrated min and max values.
105 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
106 parser.parseFloat(max) || parser.parseGreater()) {
107 parser.emitError(typeLoc, "calibrated values must be present");
108 return nullptr;
109 }
110 return type;
111 }
112
113 /// Parses an AnyQuantizedType.
114 ///
115 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
116 /// storage-spec ::= storage-type (`<` storage-range `>`)?
117 /// storage-range ::= integer-literal `:` integer-literal
118 /// storage-type ::= (`i` | `u`) integer-literal
119 /// expressed-type-spec ::= `:` `f` integer-literal
parseAnyType(DialectAsmParser & parser,Location loc)120 static Type parseAnyType(DialectAsmParser &parser, Location loc) {
121 IntegerType storageType;
122 FloatType expressedType;
123 unsigned typeFlags = 0;
124 int64_t storageTypeMin;
125 int64_t storageTypeMax;
126
127 // Type specification.
128 if (parser.parseLess())
129 return nullptr;
130
131 // Storage type.
132 bool isSigned = false;
133 storageType = parseStorageType(parser, isSigned);
134 if (!storageType) {
135 return nullptr;
136 }
137 if (isSigned) {
138 typeFlags |= QuantizationFlags::Signed;
139 }
140
141 // Storage type range.
142 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
143 storageTypeMax)) {
144 return nullptr;
145 }
146
147 // Optional expressed type.
148 if (succeeded(parser.parseOptionalColon())) {
149 if (parser.parseType(expressedType)) {
150 return nullptr;
151 }
152 }
153
154 if (parser.parseGreater()) {
155 return nullptr;
156 }
157
158 return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType,
159 storageTypeMin, storageTypeMax, loc);
160 }
161
parseQuantParams(DialectAsmParser & parser,double & scale,int64_t & zeroPoint)162 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
163 int64_t &zeroPoint) {
164 // scale[:zeroPoint]?
165 // scale.
166 if (parser.parseFloat(scale))
167 return failure();
168
169 // zero point.
170 zeroPoint = 0;
171 if (failed(parser.parseOptionalColon())) {
172 // Default zero point.
173 return success();
174 }
175
176 return parser.parseInteger(zeroPoint);
177 }
178
179 /// Parses a UniformQuantizedType.
180 ///
181 /// uniform_type ::= uniform_per_layer
182 /// | uniform_per_axis
183 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
184 /// `,` scale-zero `>`
185 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
186 /// axis-spec `,` scale-zero-list `>`
187 /// storage-spec ::= storage-type (`<` storage-range `>`)?
188 /// storage-range ::= integer-literal `:` integer-literal
189 /// storage-type ::= (`i` | `u`) integer-literal
190 /// expressed-type-spec ::= `:` `f` integer-literal
191 /// axis-spec ::= `:` integer-literal
192 /// scale-zero ::= float-literal `:` integer-literal
193 /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
parseUniformType(DialectAsmParser & parser,Location loc)194 static Type parseUniformType(DialectAsmParser &parser, Location loc) {
195 IntegerType storageType;
196 FloatType expressedType;
197 unsigned typeFlags = 0;
198 int64_t storageTypeMin;
199 int64_t storageTypeMax;
200 bool isPerAxis = false;
201 int32_t quantizedDimension;
202 SmallVector<double, 1> scales;
203 SmallVector<int64_t, 1> zeroPoints;
204
205 // Type specification.
206 if (parser.parseLess()) {
207 return nullptr;
208 }
209
210 // Storage type.
211 bool isSigned = false;
212 storageType = parseStorageType(parser, isSigned);
213 if (!storageType) {
214 return nullptr;
215 }
216 if (isSigned) {
217 typeFlags |= QuantizationFlags::Signed;
218 }
219
220 // Storage type range.
221 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
222 storageTypeMax)) {
223 return nullptr;
224 }
225
226 // Expressed type.
227 if (parser.parseColon() || parser.parseType(expressedType)) {
228 return nullptr;
229 }
230
231 // Optionally parse quantized dimension for per-axis quantization.
232 if (succeeded(parser.parseOptionalColon())) {
233 if (parser.parseInteger(quantizedDimension))
234 return nullptr;
235 isPerAxis = true;
236 }
237
238 // Comma leading into range_spec.
239 if (parser.parseComma()) {
240 return nullptr;
241 }
242
243 // Parameter specification.
244 // For per-axis, ranges are in a {} delimitted list.
245 if (isPerAxis) {
246 if (parser.parseLBrace()) {
247 return nullptr;
248 }
249 }
250
251 // Parse scales/zeroPoints.
252 llvm::SMLoc scaleZPLoc = parser.getCurrentLocation();
253 do {
254 scales.resize(scales.size() + 1);
255 zeroPoints.resize(zeroPoints.size() + 1);
256 if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
257 return nullptr;
258 }
259 } while (isPerAxis && succeeded(parser.parseOptionalComma()));
260
261 if (isPerAxis) {
262 if (parser.parseRBrace()) {
263 return nullptr;
264 }
265 }
266
267 if (parser.parseGreater()) {
268 return nullptr;
269 }
270
271 if (!isPerAxis && scales.size() > 1) {
272 return (parser.emitError(scaleZPLoc,
273 "multiple scales/zeroPoints provided, but "
274 "quantizedDimension wasn't specified"),
275 nullptr);
276 }
277
278 if (isPerAxis) {
279 ArrayRef<double> scalesRef(scales.begin(), scales.end());
280 ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
281 return UniformQuantizedPerAxisType::getChecked(
282 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
283 quantizedDimension, storageTypeMin, storageTypeMax, loc);
284 }
285
286 return UniformQuantizedType::getChecked(typeFlags, storageType, expressedType,
287 scales.front(), zeroPoints.front(),
288 storageTypeMin, storageTypeMax, loc);
289 }
290
291 /// Parses an CalibratedQuantizedType.
292 ///
293 /// calibrated ::= `calibrated<` expressed-spec `>`
294 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
295 /// expressed-type ::= `f` integer-literal
296 /// calibrated-range ::= float-literal `:` float-literal
parseCalibratedType(DialectAsmParser & parser,Location loc)297 static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
298 FloatType expressedType;
299 double min;
300 double max;
301
302 // Type specification.
303 if (parser.parseLess())
304 return nullptr;
305
306 // Expressed type.
307 expressedType = parseExpressedTypeAndRange(parser, min, max);
308 if (!expressedType) {
309 return nullptr;
310 }
311
312 if (parser.parseGreater()) {
313 return nullptr;
314 }
315
316 return CalibratedQuantizedType::getChecked(expressedType, min, max, loc);
317 }
318
319 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const320 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
321 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
322
323 // All types start with an identifier that we switch on.
324 StringRef typeNameSpelling;
325 if (failed(parser.parseKeyword(&typeNameSpelling)))
326 return nullptr;
327
328 if (typeNameSpelling == "uniform")
329 return parseUniformType(parser, loc);
330 if (typeNameSpelling == "any")
331 return parseAnyType(parser, loc);
332 if (typeNameSpelling == "calibrated")
333 return parseCalibratedType(parser, loc);
334
335 parser.emitError(parser.getNameLoc(),
336 "unknown quantized type " + typeNameSpelling);
337 return nullptr;
338 }
339
printStorageType(QuantizedType type,DialectAsmPrinter & out)340 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
341 // storage type
342 unsigned storageWidth = type.getStorageTypeIntegralWidth();
343 bool isSigned = type.isSigned();
344 if (isSigned) {
345 out << "i" << storageWidth;
346 } else {
347 out << "u" << storageWidth;
348 }
349
350 // storageTypeMin and storageTypeMax if not default.
351 int64_t defaultIntegerMin =
352 QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
353 int64_t defaultIntegerMax =
354 QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
355 if (defaultIntegerMin != type.getStorageTypeMin() ||
356 defaultIntegerMax != type.getStorageTypeMax()) {
357 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
358 << ">";
359 }
360 }
361
printQuantParams(double scale,int64_t zeroPoint,DialectAsmPrinter & out)362 static void printQuantParams(double scale, int64_t zeroPoint,
363 DialectAsmPrinter &out) {
364 out << scale;
365 if (zeroPoint != 0) {
366 out << ":" << zeroPoint;
367 }
368 }
369
370 /// Helper that prints a AnyQuantizedType.
printAnyQuantizedType(AnyQuantizedType type,DialectAsmPrinter & out)371 static void printAnyQuantizedType(AnyQuantizedType type,
372 DialectAsmPrinter &out) {
373 out << "any<";
374 printStorageType(type, out);
375 if (Type expressedType = type.getExpressedType()) {
376 out << ":" << expressedType;
377 }
378 out << ">";
379 }
380
381 /// Helper that prints a UniformQuantizedType.
printUniformQuantizedType(UniformQuantizedType type,DialectAsmPrinter & out)382 static void printUniformQuantizedType(UniformQuantizedType type,
383 DialectAsmPrinter &out) {
384 out << "uniform<";
385 printStorageType(type, out);
386 out << ":" << type.getExpressedType() << ", ";
387
388 // scheme specific parameters
389 printQuantParams(type.getScale(), type.getZeroPoint(), out);
390 out << ">";
391 }
392
393 /// Helper that prints a UniformQuantizedPerAxisType.
printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,DialectAsmPrinter & out)394 static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
395 DialectAsmPrinter &out) {
396 out << "uniform<";
397 printStorageType(type, out);
398 out << ":" << type.getExpressedType() << ":";
399 out << type.getQuantizedDimension();
400 out << ", ";
401
402 // scheme specific parameters
403 ArrayRef<double> scales = type.getScales();
404 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
405 out << "{";
406 llvm::interleave(
407 llvm::seq<size_t>(0, scales.size()), out,
408 [&](size_t index) {
409 printQuantParams(scales[index], zeroPoints[index], out);
410 },
411 ",");
412 out << "}>";
413 }
414
415 /// Helper that prints a CalibratedQuantizedType.
printCalibratedQuantizedType(CalibratedQuantizedType type,DialectAsmPrinter & out)416 static void printCalibratedQuantizedType(CalibratedQuantizedType type,
417 DialectAsmPrinter &out) {
418 out << "calibrated<" << type.getExpressedType();
419 out << "<" << type.getMin() << ":" << type.getMax() << ">";
420 out << ">";
421 }
422
423 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const424 void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
425 if (auto anyType = type.dyn_cast<AnyQuantizedType>())
426 printAnyQuantizedType(anyType, os);
427 else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
428 printUniformQuantizedType(uniformType, os);
429 else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
430 printUniformQuantizedPerAxisType(perAxisType, os);
431 else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
432 printCalibratedQuantizedType(calibratedType, os);
433 else
434 llvm_unreachable("Unhandled quantized type");
435 }
436