/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies some clean up steps after quantization.
17
18 #include <utility>
19
20 #include "llvm/Support/Casting.h"
21 #include "mlir/IR/MLIRContext.h" // from @llvm-project
22 #include "mlir/Pass/Pass.h" // from @llvm-project
23 #include "mlir/Support/LogicalResult.h" // from @llvm-project
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
26 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
27 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
28
//===----------------------------------------------------------------------===//
// The post-quantize passes.
//
32 namespace mlir {
33 namespace TFL {
34 namespace {
35
36 // Applies all the clean up steps after quantization.
37 class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
38 public:
39 // Constructor used by the PassRegistration. This will remove the adaptor ops.
PostQuantizePass()40 explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
41
42 // Constructor used by manually creating the pass.
PostQuantizePass(bool emit_quant_adaptor_ops)43 explicit PostQuantizePass(bool emit_quant_adaptor_ops)
44 : emit_quant_adaptor_ops_(emit_quant_adaptor_ops) {}
45
getArgument() const46 StringRef getArgument() const final {
47 // This is the argument used to refer to the pass in
48 // the textual format (on the commandline for example).
49 return "tfl-post-quantize";
50 }
getDescription() const51 StringRef getDescription() const final {
52 // This is a brief description of the pass.
53 return "Apply post quantization clean up after quantization";
54 }
55
56 void runOnFunction() override;
57
58 private:
59 // Set this flag to true if the inputs and outputs are in floating point. The
60 // quant adaptor ops convert them to fixed point values (i.e. quantize) before
61 // feeding them to the model and convert them back to floating point
62 // (i.e. dequantize) as the output.
63 bool emit_quant_adaptor_ops_;
64 };
65
66 // Cleans up unnecessary QDQ pattern for input/output ops.
67 class PostQuantizeRemoveQDQPass
68 : public PassWrapper<PostQuantizeRemoveQDQPass, FunctionPass> {
69 public:
70 // Constructor used by the PassRegistration. This will remove QDQ ops.
PostQuantizeRemoveQDQPass()71 explicit PostQuantizeRemoveQDQPass() {}
72
getArgument() const73 StringRef getArgument() const final {
74 // This is the argument used to refer to the pass in
75 // the textual format (on the commandline for example).
76 return "tfl-post-quantize-remove-qdq";
77 }
getDescription() const78 StringRef getDescription() const final {
79 // This is a brief description of the pass.
80 return "Remove qdq from input and output nodes after quantization";
81 }
82
83 void runOnFunction() override;
84 };
85
// TODO(fengliuai): migrate to use modify_io_nodes pass.
// Rewrites `func` in place so its signature uses the quantized types directly:
// each `arg -> tfl.quantize` input pattern is collapsed (the argument takes the
// quantize op's result type), and each `tfl.dequantize -> return` output
// pattern is collapsed (the return operand becomes the dequantize op's input).
// Finally the function type is updated to match the new argument/result types.
void RemoveQuantizationAdaptorOps(FuncOp func) {
  mlir::OpBuilder builder(func.getBody());
  auto& bb = func.front();

  int num_args = bb.getNumArguments();
  llvm::SmallVector<Type, 4> input_types;
  input_types.reserve(num_args);
  // Edit the block arguments and create the new input ops in place to replace
  // the old input ops and quantize ops.
  for (int i = 0; i != num_args; ++i) {
    // Previous loop iteration may invalidate the insertion point so we have to
    // reset insertion point each iteration.
    builder.setInsertionPointToStart(&bb);

    // In each iteration, a new argument is appended to the end of the list
    // and the current argument is erased, so here we always process the first
    // argument in the list. (After num_args iterations the appended arguments
    // are back in the original order.)
    auto arg = bb.getArgument(0);

    // Replaces the quantize op on `arg`: the new trailing argument directly
    // carries the quantized type, and the quantize op plus the old argument
    // are removed.
    auto remove_quantize_op = [&](QuantizeOp quantize_op) {
      auto quantize_output = quantize_op.output();
      auto quantize_type = quantize_output.getType();
      input_types.push_back(quantize_type);
      auto new_arg = bb.addArgument(quantize_type);
      quantize_output.replaceAllUsesWith(new_arg);
      quantize_op.erase();
      arg.dropAllUses();
      bb.eraseArgument(0);
    };

    // This is looking for a pattern: arg -> tfl.quantize
    // Only fires when the quantize op is the argument's sole user, so no other
    // consumer of the float argument is left dangling.
    if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
      auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
      remove_quantize_op(quantize_op);
      continue;
    }

    // Make a copy of current argument and append it to the end of the list if
    // the pattern isn't found.
    Type arg_type = arg.getType();
    input_types.push_back(arg_type);
    auto new_arg = bb.addArgument(arg_type);
    arg.replaceAllUsesWith(new_arg);
    arg.dropAllUses();
    bb.eraseArgument(0);
  }

  // Edit the return ops and remove the dequantize ops in place.
  auto* terminator = bb.getTerminator();
  int num_return_operands = terminator->getNumOperands();
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(num_return_operands);
  for (int i = 0; i != num_return_operands; ++i) {
    auto returned_value = terminator->getOperand(i);
    Operation* returned_op = returned_value.getDefiningOp();
    // Only fold a dequantize whose single use is this return operand; a
    // multi-use dequantize still feeds other ops and must be kept.
    if (returned_op && returned_op->hasOneUse() &&
        llvm::isa<DequantizeOp>(returned_op)) {
      auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
      Value dequantized_result = dequantize_op.input();
      output_types.push_back(dequantized_result.getType());
      terminator->setOperand(i, dequantized_result);
      returned_op->erase();
    } else {
      output_types.push_back(returned_value.getType());
    }
  }
  // Keep the function type in sync with the edited block signature.
  auto new_func_type = builder.getFunctionType(input_types, output_types);
  func.setType(new_func_type);
}
156
// Selects which volatile quantize/dequantize pairs the RemoveVolatileOps
// pattern is allowed to erase.
enum RemoveVolatileOpsType {
  // Remove all volatile quant-dequant ops.
  kPreserveNone,
  // Preserve volatile quant-dequants for input and output ops.
  kPreserveInputsAndOutputs,
};
163
164 // Remove the back-to-back quantize and dequantize ops with volatile attribute.
165 template <RemoveVolatileOpsType remove_volatile_ops_type>
166 struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
RemoveVolatileOpsmlir::TFL::__anonbfc12d730111::RemoveVolatileOps167 explicit RemoveVolatileOps(MLIRContext* context)
168 : OpRewritePattern<DequantizeOp>(context, 1) {}
169
matchAndRewritemlir::TFL::__anonbfc12d730111::RemoveVolatileOps170 LogicalResult matchAndRewrite(DequantizeOp op,
171 PatternRewriter& rewriter) const override {
172 auto input_op = op.input().getDefiningOp();
173 if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
174 if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
175
176 if (remove_volatile_ops_type == kPreserveInputsAndOutputs) {
177 // Don't remove leading and tailing QDQ for PQT workflow, so the io
178 // modifying lib can work correctly.
179 if (!q.input().getDefiningOp()) return failure();
180 if (op->hasOneUse() &&
181 op->user_begin()->hasTrait<OpTrait::IsTerminator>())
182 return failure();
183 }
184 // If the quantize op is a requantize op, it is being used in other scale
185 // adjustments and should be kept. Instead, moving dequantize op before
186 // the requantize op to remove the unnecessary requantize op.
187 if (auto qtype = quant::QuantizedType::getQuantizedElementType(
188 q.input().getType())) {
189 rewriter.setInsertionPoint(op);
190 rewriter.replaceOpWithNewOp<DequantizeOp>(op, op.output().getType(),
191 q.input());
192 return success();
193 }
194
195 op.replaceAllUsesWith(q.input());
196 return success();
197 }
198 return failure();
199 }
200 };
201
202 // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
203 // output.
204 template <typename OpTy>
205 struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
206 public:
PruneUnusedOpsWithSideEffectmlir::TFL::__anonbfc12d730111::PruneUnusedOpsWithSideEffect207 explicit PruneUnusedOpsWithSideEffect(MLIRContext* context)
208 : OpRewritePattern<OpTy>(context) {}
209
matchAndRewritemlir::TFL::__anonbfc12d730111::PruneUnusedOpsWithSideEffect210 LogicalResult matchAndRewrite(OpTy op,
211 PatternRewriter& rewriter) const override {
212 if (op.getOperation()->template hasTrait<OpTrait::IsTerminator>()) {
213 return failure();
214 }
215 for (auto result : op.getOperation()->getOpResults()) {
216 if (!result.use_empty()) {
217 return failure();
218 }
219 }
220 rewriter.eraseOp(op);
221 return success();
222 }
223 };
224
225 #include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
226
runOnFunction()227 void PostQuantizePass::runOnFunction() {
228 OwningRewritePatternList patterns(&getContext());
229 auto func = getFunction();
230 auto* ctx = func.getContext();
231 TFL::populateWithGenerated(patterns);
232 patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
233 patterns.insert<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
234 patterns
235 .insert<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
236 ctx);
237 patterns.insert<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
238 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
239
240 if (!emit_quant_adaptor_ops_) {
241 RemoveQuantizationAdaptorOps(getFunction());
242 }
243
244 OwningRewritePatternList phase_2_patterns(&getContext());
245 TFL::populateWithGenerated(phase_2_patterns);
246 phase_2_patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>,
247 RemoveVolatileOps<kPreserveInputsAndOutputs>>(ctx);
248 (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
249 }
250
runOnFunction()251 void PostQuantizeRemoveQDQPass::runOnFunction() {
252 OwningRewritePatternList patterns(&getContext());
253 auto func = getFunction();
254 auto* ctx = func.getContext();
255 TFL::populateWithGenerated(patterns);
256 patterns.insert<RemoveVolatileOps<kPreserveNone>>(ctx);
257 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
258 }
259
260 } // namespace
261
262 // Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
CreatePostQuantizePass(bool emit_quant_adaptor_ops)263 std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
264 bool emit_quant_adaptor_ops) {
265 return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
266 }
267
268 // Creates an instance of the TensorFlow Lite dialect PostQuantizeRemoveQDQ
269 // pass.
CreatePostQuantizeRemoveQDQPass()270 std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizeRemoveQDQPass() {
271 return std::make_unique<PostQuantizeRemoveQDQPass>();
272 }
273
// Registers PostQuantizePass globally under "tfl-post-quantize".
static PassRegistration<PostQuantizePass> pass;

// Registers PostQuantizeRemoveQDQPass globally under
// "tfl-post-quantize-remove-qdq".
static PassRegistration<PostQuantizeRemoveQDQPass> remove_qdq_pass;
277
278 } // namespace TFL
279 } // namespace mlir
280