1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass applies some clean up steps after quantization.
17
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/MLIRContext.h" // from @llvm-project
20 #include "mlir/Pass/Pass.h" // from @llvm-project
21 #include "mlir/Support/LogicalResult.h" // from @llvm-project
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
24 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26
27 //===----------------------------------------------------------------------===//
28 // The post-quantize Pass.
29 //
30 namespace mlir {
31 namespace TFL {
32 namespace {
33
34 // Applies all the clean up steps after quantization.
35 class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
36 public:
37 // Constructor used by the PassRegistration. This will remove the adaptor ops.
PostQuantizePass()38 explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
39
40 // Constructor used by manually creating the pass.
PostQuantizePass(bool emit_quant_adaptor_ops)41 explicit PostQuantizePass(bool emit_quant_adaptor_ops)
42 : emit_quant_adaptor_ops_(emit_quant_adaptor_ops) {}
43
44 void runOnFunction() override;
45
46 private:
47 // Set this flag to true if the inputs and outputs are in floating point. The
48 // quant adaptor ops convert them to fixed point values (i.e. quantize) before
49 // feeding them to the model and convert them back to floating point
50 // (i.e. dequantize) as the output.
51 bool emit_quant_adaptor_ops_;
52 };
53
RemoveQuantizationAdaptorOps(FuncOp func)54 void RemoveQuantizationAdaptorOps(FuncOp func) {
55 mlir::OpBuilder builder(func.getBody());
56 auto& bb = func.front();
57
58 int num_args = bb.getNumArguments();
59 llvm::SmallVector<Type, 4> input_types;
60 input_types.reserve(num_args);
61 // Edit the block arguments and create the new input ops in place to replace
62 // the old input ops and quantize ops.
63 for (int i = 0; i != num_args; ++i) {
64 // Previous loop iteration may invalidate the insertion point so we have to
65 // reset insertion point each iteration.
66 builder.setInsertionPointToStart(&bb);
67
68 // In each iteration, a new argument is appended to the end of the list
69 // and the current argument is erased, so here we always process the first
70 // argument in the list.
71 auto arg = bb.getArgument(0);
72
73 auto remove_quantize_op = [&](QuantizeOp quantize_op) {
74 auto quantize_output = quantize_op.output();
75 auto quantize_type = quantize_output.getType();
76 input_types.push_back(quantize_type);
77 auto new_arg = bb.addArgument(quantize_type);
78 quantize_output.replaceAllUsesWith(new_arg);
79 quantize_op.erase();
80 arg.dropAllUses();
81 bb.eraseArgument(0);
82 };
83
84 // This is looking for a pattern: arg -> tfl.quantize
85 if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
86 auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
87 remove_quantize_op(quantize_op);
88 continue;
89 }
90
91 // Make a copy of current argument and append it to the end of the list if
92 // the pattern isn't found.
93 Type arg_type = arg.getType();
94 input_types.push_back(arg_type);
95 auto new_arg = bb.addArgument(arg_type);
96 arg.replaceAllUsesWith(new_arg);
97 arg.dropAllUses();
98 bb.eraseArgument(0);
99 }
100
101 // Edit the return ops and remove the dequantize ops in place.
102 auto* terminator = bb.getTerminator();
103 int num_return_operands = terminator->getNumOperands();
104 llvm::SmallVector<Type, 4> output_types;
105 output_types.reserve(num_return_operands);
106 for (int i = 0; i != num_return_operands; ++i) {
107 auto returned_value = terminator->getOperand(i);
108 Operation* returned_op = returned_value.getDefiningOp();
109 if (returned_op && returned_op->hasOneUse() &&
110 llvm::isa<DequantizeOp>(returned_op)) {
111 auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
112 Value dequantized_result = dequantize_op.input();
113 output_types.push_back(dequantized_result.getType());
114 terminator->setOperand(i, dequantized_result);
115 returned_op->erase();
116 } else {
117 output_types.push_back(returned_value.getType());
118 }
119 }
120 auto new_func_type = builder.getFunctionType(input_types, output_types);
121 func.setType(new_func_type);
122 }
123
124 // Remove the back-to-back quantize and dequantize ops with volatile attribute.
125 struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
RemoveVolatileOpsmlir::TFL::__anonb308970e0111::RemoveVolatileOps126 explicit RemoveVolatileOps(MLIRContext* context)
127 : OpRewritePattern<DequantizeOp>(context, 1) {}
128
matchAndRewritemlir::TFL::__anonb308970e0111::RemoveVolatileOps129 LogicalResult matchAndRewrite(DequantizeOp op,
130 PatternRewriter& rewriter) const override {
131 auto input_op = op.input().getDefiningOp();
132 if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
133 if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
134
135 // Don't remove leading and tailing QDQ for PQT workflow, so the io
136 // modifying lib can work correctly.
137 if (!q.input().getDefiningOp()) return failure();
138 if (op->hasOneUse() &&
139 op->user_begin()->hasTrait<OpTrait::IsTerminator>())
140 return failure();
141
142 op.replaceAllUsesWith(q.input());
143 return success();
144 }
145 return failure();
146 }
147 };
148
149 // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
150 // output.
151 template <typename OpTy>
152 struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
153 public:
PruneUnusedOpsWithSideEffectmlir::TFL::__anonb308970e0111::PruneUnusedOpsWithSideEffect154 explicit PruneUnusedOpsWithSideEffect(MLIRContext* context)
155 : OpRewritePattern<OpTy>(context) {}
156
matchAndRewritemlir::TFL::__anonb308970e0111::PruneUnusedOpsWithSideEffect157 LogicalResult matchAndRewrite(OpTy op,
158 PatternRewriter& rewriter) const override {
159 if (op.getOperation()->template hasTrait<OpTrait::IsTerminator>()) {
160 return failure();
161 }
162 for (auto result : op.getOperation()->getOpResults()) {
163 if (!result.use_empty()) {
164 return failure();
165 }
166 }
167 rewriter.eraseOp(op);
168 return success();
169 }
170 };
171
172 #include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
173
runOnFunction()174 void PostQuantizePass::runOnFunction() {
175 OwningRewritePatternList patterns;
176 auto func = getFunction();
177 auto* ctx = func.getContext();
178 TFL::populateWithGenerated(ctx, patterns);
179 patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
180 patterns.insert<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
181 patterns
182 .insert<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
183 ctx);
184 patterns.insert<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
185 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
186
187 if (!emit_quant_adaptor_ops_) {
188 RemoveQuantizationAdaptorOps(getFunction());
189 }
190
191 OwningRewritePatternList phase_2_patterns;
192 TFL::populateWithGenerated(ctx, phase_2_patterns);
193 phase_2_patterns
194 .insert<quant::FoldTrivalRequantizeOp<QuantizeOp>, RemoveVolatileOps>(
195 ctx);
196 (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
197 }
198
199 } // namespace
200
201 // Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
CreatePostQuantizePass(bool emit_quant_adaptor_ops)202 std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
203 bool emit_quant_adaptor_ops) {
204 return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
205 }
206
207 static PassRegistration<PostQuantizePass> pass(
208 "tfl-post-quantize", "Apply post quantization clean up after quantization");
209
210 } // namespace TFL
211 } // namespace mlir
212