1 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Quant/Passes.h"
11 #include "mlir/Dialect/Quant/QuantOps.h"
12 #include "mlir/Dialect/Quant/QuantizeUtils.h"
13 #include "mlir/Dialect/Quant/UniformSupport.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19 using namespace mlir;
20 using namespace mlir::quant;
21
22 namespace {
23 struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> {
24 void runOnFunction() override;
25 };
26
27 struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
28 using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
29
30 LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
31 PatternRewriter &rewriter) const override;
32 };
33
34 } // end anonymous namespace
35
36 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
37 /// quantized and the operand type is quantizable.
38
39 LogicalResult
matchAndRewrite(QuantizeCastOp qbarrier,PatternRewriter & rewriter) const40 QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
41 PatternRewriter &rewriter) const {
42 Attribute value;
43
44 // Is the operand a constant?
45 if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
46 return failure();
47 }
48
49 // Does the qbarrier convert to a quantized type. This will not be true
50 // if a quantized type has not yet been chosen or if the cast to an equivalent
51 // storage type is not supported.
52 Type qbarrierResultType = qbarrier.getResult().getType();
53 QuantizedType quantizedElementType =
54 QuantizedType::getQuantizedElementType(qbarrierResultType);
55 if (!quantizedElementType) {
56 return failure();
57 }
58 if (!QuantizedType::castToStorageType(qbarrierResultType)) {
59 return failure();
60 }
61
62 // Is the operand type compatible with the expressed type of the quantized
63 // type? This will not be true if the qbarrier is superfluous (converts
64 // from and to a quantized type).
65 if (!quantizedElementType.isCompatibleExpressedType(
66 qbarrier.arg().getType())) {
67 return failure();
68 }
69
70 // Is the constant value a type expressed in a way that we support?
71 if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
72 return failure();
73 }
74
75 Type newConstValueType;
76 auto newConstValue =
77 quantizeAttr(value, quantizedElementType, newConstValueType);
78 if (!newConstValue) {
79 return failure();
80 }
81
82 // When creating the new const op, use a fused location that combines the
83 // original const and the qbarrier that led to the quantization.
84 auto fusedLoc = FusedLoc::get(
85 {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()},
86 rewriter.getContext());
87 auto newConstOp =
88 rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
89 rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
90 newConstOp);
91 return success();
92 }
93
runOnFunction()94 void ConvertConstPass::runOnFunction() {
95 OwningRewritePatternList patterns;
96 auto func = getFunction();
97 auto *context = &getContext();
98 patterns.insert<QuantizedConstRewrite>(context);
99 applyPatternsAndFoldGreedily(func, std::move(patterns));
100 }
101
createConvertConstPass()102 std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() {
103 return std::make_unique<ConvertConstPass>();
104 }
105