1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "TypeDetail.h"
11
12 #include "mlir/Dialect/Quant/QuantTypes.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/MathExtras.h"
20 #include <numeric>
21
22 using namespace mlir;
23 using namespace mlir::quant;
24 using namespace mlir::quant::detail;
25
initialize()26 void QuantizationDialect::initialize() {
27 addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
28 UniformQuantizedPerAxisType>();
29 addOperations<
30 #define GET_OP_LIST
31 #include "mlir/Dialect/Quant/QuantOps.cpp.inc"
32 >();
33 }
34
fold(ArrayRef<Attribute> operands)35 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
36 // Matches x -> [scast -> scast] -> y, replacing the second scast with the
37 // value of x if the casts invert each other.
38 auto srcScastOp = arg().getDefiningOp<StorageCastOp>();
39 if (!srcScastOp || srcScastOp.arg().getType() != getType())
40 return OpFoldResult();
41 return srcScastOp.arg();
42 }
43
44 /// The quantization specification should match the expressed type.
isValidQuantizationSpec(Attribute quantSpec,Type expressed)45 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
46 if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
47 Type spec = typeAttr.getValue();
48 if (spec.isa<TensorType, VectorType>())
49 return false;
50
51 // The spec should be either a quantized type which is compatible to the
52 // expressed type, or a primitive type which is as same as the
53 // (element type of) the expressed type.
54 if (auto quantizedType = spec.dyn_cast<QuantizedType>())
55 return quantizedType.isCompatibleExpressedType(expressed);
56
57 if (auto tensorType = expressed.dyn_cast<TensorType>())
58 return spec == tensorType.getElementType();
59
60 if (auto vectorType = expressed.dyn_cast<VectorType>())
61 return spec == vectorType.getElementType();
62 }
63 return false;
64 }
65
verifyRegionOp(QuantizeRegionOp op)66 static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
67 // There are specifications for both inputs and outputs.
68 if (op.getNumOperands() != op.input_specs().size() ||
69 op.getNumResults() != op.output_specs().size())
70 return op.emitOpError(
71 "has unmatched operands/results number and spec attributes number");
72
73 // Verify that quantization specifications are valid.
74 for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) {
75 Type inputType = std::get<0>(input);
76 Attribute inputSpec = std::get<1>(input);
77 if (!isValidQuantizationSpec(inputSpec, inputType)) {
78 return op.emitOpError() << "has incompatible specification " << inputSpec
79 << " and input type " << inputType;
80 }
81 }
82
83 for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) {
84 Type outputType = std::get<0>(result);
85 Attribute outputSpec = std::get<1>(result);
86 if (!isValidQuantizationSpec(outputSpec, outputType)) {
87 return op.emitOpError() << "has incompatible specification " << outputSpec
88 << " and output type " << outputType;
89 }
90 }
91 return success();
92 }
93
94 #define GET_OP_CLASSES
95 #include "mlir/Dialect/Quant/QuantOps.cpp.inc"
96