/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass decomposes dense operations that assume
// support for hybrid quantization. These cases cover when a dense operation
// (e.g. matmul) has both quantized and unquantized inputs by dequantizing
// the quantized inputs, performing the operation in the expressed type, then
// requantizing if a quantized output is required.
//
// The motivation behind these changes is for Dialects that assume only float
// or quantized computation, and do not support a mixture of these types on
// dense operations. Decomposition allows TFLite to be compiled to these
// dialects, such as TOSA.

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
33 
34 namespace mlir {
35 namespace TFL {
36 
37 namespace {
38 
39 class DecomposeHybridQuantizationPass
40     : public PassWrapper<DecomposeHybridQuantizationPass, FunctionPass> {
41  public:
42   DecomposeHybridQuantizationPass() = default;
DecomposeHybridQuantizationPass(const DecomposeHybridQuantizationPass &)43   DecomposeHybridQuantizationPass(const DecomposeHybridQuantizationPass &) {}
44 
getArgument() const45   StringRef getArgument() const override {
46     return "tfl-decompose-hybrid-quantization";
47   }
48 
getDescription() const49   StringRef getDescription() const override {
50     return "Decomposes (with explicit quantize/dequantize ops) selected math "
51            "operations which exist in the model with hybrid quantization "
52            "(some arguments/results left in floating point).";
53   }
54 
55   void runOnFunction() override;
56 
getDependentDialects(DialectRegistry & registry) const57   void getDependentDialects(DialectRegistry &registry) const override {
58     registry.insert<TFL::TensorFlowLiteDialect>();
59   }
60 };
61 
62 template <typename SrcOp>
63 class DequantizeConverter : public OpRewritePattern<SrcOp> {
64  public:
65   using OpRewritePattern<SrcOp>::OpRewritePattern;
66 
matchAndRewrite(SrcOp srcop,PatternRewriter & rewriter) const67   LogicalResult matchAndRewrite(SrcOp srcop,
68                                 PatternRewriter &rewriter) const final {
69     Operation *op = srcop.getOperation();
70     bool allTypesFp = true;
71     bool allTypesQuantized = true;
72     for (auto operand : op->getOperands()) {
73       ShapedType type = operand.getType().template dyn_cast<ShapedType>();
74       if (!type) continue;
75       allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
76       allTypesQuantized &= type.getElementType().isa<quant::QuantizedType>();
77     }
78 
79     for (auto result : op->getResults()) {
80       ShapedType type = result.getType().template cast<ShapedType>();
81       allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
82       allTypesQuantized &= type.getElementType().isa<quant::QuantizedType>();
83     }
84 
85     // If all quantized or floating point then types are consistent.
86     if (allTypesFp || allTypesQuantized) return failure();
87 
88     Location loc = op->getLoc();
89     SmallVector<Value> newOperands;
90     newOperands.reserve(op->getNumOperands());
91     for (auto operand : op->getOperands()) {
92       if (QuantizedType::getQuantizedElementType(operand.getType())) {
93         auto newTy = QuantizedType::castToExpressedType(operand.getType());
94         newOperands.push_back(
95             rewriter.create<TFL::DequantizeOp>(loc, newTy, operand));
96         continue;
97       }
98 
99       newOperands.push_back(operand);
100     }
101 
102     SmallVector<Type> newResultTys;
103     for (auto result : op->getResults()) {
104       Type resultTy = result.getType();
105       if (QuantizedType::getQuantizedElementType(resultTy)) {
106         resultTy = QuantizedType::castToExpressedType(resultTy);
107       }
108       newResultTys.push_back(resultTy);
109     }
110 
111     auto newResults = rewriter
112                           .create<SrcOp>(loc, newResultTys, newOperands,
113                                          op->getAttrDictionary().getValue())
114                           .getOperation()
115                           ->getResults();
116 
117     SmallVector<Value> replaceResults;
118     for (int i = 0; i < newResults.size(); i++) {
119       Value result = newResults[i];
120       Type resultTy = op->getOpResult(i).getType();
121       if (QuantizedType::getQuantizedElementType(resultTy)) {
122         replaceResults.push_back(rewriter.create<TFL::QuantizeOp>(
123             loc, resultTy, result, TypeAttr::get(resultTy)));
124         continue;
125       }
126 
127       replaceResults.push_back(result);
128     }
129 
130     rewriter.replaceOp(op, replaceResults);
131 
132     return success();
133   }
134 };
135 
runOnFunction()136 void DecomposeHybridQuantizationPass::runOnFunction() {
137   OwningRewritePatternList patterns(&getContext());
138   auto *ctx = &getContext();
139   auto func = getFunction();
140   patterns.insert<DequantizeConverter<TFL::Conv2DOp>,
141                   DequantizeConverter<TFL::Conv3DOp>,
142                   DequantizeConverter<TFL::DepthwiseConv2DOp>,
143                   DequantizeConverter<TFL::FullyConnectedOp>,
144                   DequantizeConverter<TFL::TransposeConvOp>>(ctx);
145   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
146 }
147 
148 }  // namespace
149 
CreateDecomposeHybridQuantizationPass()150 std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeHybridQuantizationPass() {
151   return std::make_unique<DecomposeHybridQuantizationPass>();
152 }
153 
154 static PassRegistration<DecomposeHybridQuantizationPass> pass;
155 
156 }  // namespace TFL
157 }  // namespace mlir
158