1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass decomposes dense operations that assume
17 // support for hybrid quantization. These cases cover when a dense operation
18 // (e.g. matmul) has both quantized and unquantized inputs by dequantizing
19 // the quantized inputs, performing the operation in the expressed type, then
20 // requantizing if a quantized output is required.
21 //
22 // The motivation behind these changes is for Dialects that assume only float
23 // or quantized computation, and do not support a mixture of these types on
24 // dense operations. Decomposition allows TFLite to be compiled to these
25 // dialects, such as TOSA.
26
27 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
33
34 namespace mlir {
35 namespace TFL {
36
37 namespace {
38
// Function pass that rewrites "hybrid" dense ops — ops mixing quantized and
// floating-point operand/result types — into explicit
// dequantize -> float op -> quantize sequences via the patterns below.
class DecomposeHybridQuantizationPass
    : public PassWrapper<DecomposeHybridQuantizationPass, FunctionPass> {
 public:
  DecomposeHybridQuantizationPass() = default;
  // The pass manager clones passes; this pass carries no state, so the copy
  // constructor is intentionally a no-op.
  DecomposeHybridQuantizationPass(const DecomposeHybridQuantizationPass &) {}

  // Command-line name used to schedule this pass (e.g. from tf-opt).
  StringRef getArgument() const override {
    return "tfl-decompose-hybrid-quantization";
  }

  // One-line summary shown in pass documentation / --help output.
  StringRef getDescription() const override {
    return "Decomposes (with explicit quantize/dequantize ops) selected math "
           "operations which exist in the model with hybrid quantization "
           "(some arguments/results left in floating point).";
  }

  void runOnFunction() override;

  // The rewrite creates TFL::QuantizeOp / TFL::DequantizeOp, so the TFL
  // dialect must be registered before the pass runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect>();
  }
};
61
62 template <typename SrcOp>
63 class DequantizeConverter : public OpRewritePattern<SrcOp> {
64 public:
65 using OpRewritePattern<SrcOp>::OpRewritePattern;
66
matchAndRewrite(SrcOp srcop,PatternRewriter & rewriter) const67 LogicalResult matchAndRewrite(SrcOp srcop,
68 PatternRewriter &rewriter) const final {
69 Operation *op = srcop.getOperation();
70 bool allTypesFp = true;
71 bool allTypesQuantized = true;
72 for (auto operand : op->getOperands()) {
73 ShapedType type = operand.getType().template dyn_cast<ShapedType>();
74 if (!type) continue;
75 allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
76 allTypesQuantized &= type.getElementType().isa<quant::QuantizedType>();
77 }
78
79 for (auto result : op->getResults()) {
80 ShapedType type = result.getType().template cast<ShapedType>();
81 allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
82 allTypesQuantized &= type.getElementType().isa<quant::QuantizedType>();
83 }
84
85 // If all quantized or floating point then types are consistent.
86 if (allTypesFp || allTypesQuantized) return failure();
87
88 Location loc = op->getLoc();
89 SmallVector<Value> newOperands;
90 newOperands.reserve(op->getNumOperands());
91 for (auto operand : op->getOperands()) {
92 if (QuantizedType::getQuantizedElementType(operand.getType())) {
93 auto newTy = QuantizedType::castToExpressedType(operand.getType());
94 newOperands.push_back(
95 rewriter.create<TFL::DequantizeOp>(loc, newTy, operand));
96 continue;
97 }
98
99 newOperands.push_back(operand);
100 }
101
102 SmallVector<Type> newResultTys;
103 for (auto result : op->getResults()) {
104 Type resultTy = result.getType();
105 if (QuantizedType::getQuantizedElementType(resultTy)) {
106 resultTy = QuantizedType::castToExpressedType(resultTy);
107 }
108 newResultTys.push_back(resultTy);
109 }
110
111 auto newResults = rewriter
112 .create<SrcOp>(loc, newResultTys, newOperands,
113 op->getAttrDictionary().getValue())
114 .getOperation()
115 ->getResults();
116
117 SmallVector<Value> replaceResults;
118 for (int i = 0; i < newResults.size(); i++) {
119 Value result = newResults[i];
120 Type resultTy = op->getOpResult(i).getType();
121 if (QuantizedType::getQuantizedElementType(resultTy)) {
122 replaceResults.push_back(rewriter.create<TFL::QuantizeOp>(
123 loc, resultTy, result, TypeAttr::get(resultTy)));
124 continue;
125 }
126
127 replaceResults.push_back(result);
128 }
129
130 rewriter.replaceOp(op, replaceResults);
131
132 return success();
133 }
134 };
135
runOnFunction()136 void DecomposeHybridQuantizationPass::runOnFunction() {
137 OwningRewritePatternList patterns(&getContext());
138 auto *ctx = &getContext();
139 auto func = getFunction();
140 patterns.insert<DequantizeConverter<TFL::Conv2DOp>,
141 DequantizeConverter<TFL::Conv3DOp>,
142 DequantizeConverter<TFL::DepthwiseConv2DOp>,
143 DequantizeConverter<TFL::FullyConnectedOp>,
144 DequantizeConverter<TFL::TransposeConvOp>>(ctx);
145 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
146 }
147
148 } // namespace
149
CreateDecomposeHybridQuantizationPass()150 std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeHybridQuantizationPass() {
151 return std::make_unique<DecomposeHybridQuantizationPass>();
152 }
153
// Registers the pass globally so tools (e.g. tf-opt) can invoke it by its
// "tfl-decompose-hybrid-quantization" argument.
static PassRegistration<DecomposeHybridQuantizationPass> pass;
155
156 } // namespace TFL
157 } // namespace mlir
158