/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies some clean up steps after quantization.

#include <utility>

#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

//===----------------------------------------------------------------------===//
// The post-quantize Passes.
//
namespace mlir {
namespace TFL {
namespace {

// Applies all the clean up steps after quantization.
class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
 public:
  // Constructor used by the PassRegistration. This will remove the adaptor
  // ops.
  explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}

  // Constructor used when manually creating the pass.
  explicit PostQuantizePass(bool emit_quant_adaptor_ops)
      : emit_quant_adaptor_ops_(emit_quant_adaptor_ops) {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the command line, for example).
    return "tfl-post-quantize";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Apply post quantization clean up after quantization";
  }

  void runOnFunction() override;

 private:
  // Set this flag to true if the inputs and outputs are in floating point.
  // The quant adaptor ops convert them to fixed-point values (i.e. quantize)
  // before feeding them to the model, and convert the outputs back to
  // floating point (i.e. dequantize).
  bool emit_quant_adaptor_ops_;
};
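
// For illustration (a sketch, not taken verbatim from any model): with quant
// adaptor ops emitted, a floating-point function boundary looks roughly like
//   %q  = "tfl.quantize"(%arg0) : (tensor<1xf32>) -> tensor<1x!quant...>
//   ...
//   %dq = "tfl.dequantize"(%res) : (tensor<1x!quant...>) -> tensor<1xf32>
// When emit_quant_adaptor_ops_ is false, RemoveQuantizationAdaptorOps below
// erases these boundary ops and rewrites the function signature to use the
// quantized types directly.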

// Cleans up unnecessary QDQ pattern for input/output ops.
class PostQuantizeRemoveQDQPass
    : public PassWrapper<PostQuantizeRemoveQDQPass, FunctionPass> {
 public:
  // Constructor used by the PassRegistration. This will remove QDQ ops.
  explicit PostQuantizeRemoveQDQPass() {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the command line, for example).
    return "tfl-post-quantize-remove-qdq";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Remove qdq from input and output nodes after quantization";
  }

  void runOnFunction() override;
};
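
// Here "QDQ" refers to a back-to-back tfl.quantize/tfl.dequantize pair; this
// pass strips such pairs left on the model inputs and outputs (see
// RemoveVolatileOps<kPreserveNone> below).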

// TODO(fengliuai): migrate to use modify_io_nodes pass.
void RemoveQuantizationAdaptorOps(FuncOp func) {
  mlir::OpBuilder builder(func.getBody());
  auto& bb = func.front();

  int num_args = bb.getNumArguments();
  llvm::SmallVector<Type, 4> input_types;
  input_types.reserve(num_args);
  // Edit the block arguments and create the new input ops in place to replace
  // the old input ops and quantize ops.
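  // For example (an illustrative sketch): an argument feeding
  //   %q = "tfl.quantize"(%arg0) : (tensor<1xf32>) -> tensor<1x!quant...>
  // is replaced by a new block argument that already carries the quantized
  // type, and the tfl.quantize op is erased.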
  for (int i = 0; i != num_args; ++i) {
    // The previous loop iteration may invalidate the insertion point, so we
    // have to reset the insertion point at each iteration.
    builder.setInsertionPointToStart(&bb);

    // In each iteration, a new argument is appended to the end of the list
    // and the current argument is erased, so here we always process the first
    // argument in the list.
    auto arg = bb.getArgument(0);

    auto remove_quantize_op = [&](QuantizeOp quantize_op) {
      auto quantize_output = quantize_op.output();
      auto quantize_type = quantize_output.getType();
      input_types.push_back(quantize_type);
      auto new_arg = bb.addArgument(quantize_type);
      quantize_output.replaceAllUsesWith(new_arg);
      quantize_op.erase();
      arg.dropAllUses();
      bb.eraseArgument(0);
    };

    // This is looking for a pattern: arg -> tfl.quantize
    if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
      auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
      remove_quantize_op(quantize_op);
      continue;
    }

    // Make a copy of the current argument and append it to the end of the
    // list if the pattern isn't found.
    Type arg_type = arg.getType();
    input_types.push_back(arg_type);
    auto new_arg = bb.addArgument(arg_type);
    arg.replaceAllUsesWith(new_arg);
    arg.dropAllUses();
    bb.eraseArgument(0);
  }

  // Edit the return ops and remove the dequantize ops in place.
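  // For example (sketch): a returned value defined by
  //   %dq = "tfl.dequantize"(%q) : (tensor<1x!quant...>) -> tensor<1xf32>
  // is replaced by %q itself, and the function result type is updated to the
  // quantized type.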
  auto* terminator = bb.getTerminator();
  int num_return_operands = terminator->getNumOperands();
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(num_return_operands);
  for (int i = 0; i != num_return_operands; ++i) {
    auto returned_value = terminator->getOperand(i);
    Operation* returned_op = returned_value.getDefiningOp();
    if (returned_op && returned_op->hasOneUse() &&
        llvm::isa<DequantizeOp>(returned_op)) {
      auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
      Value dequantized_result = dequantize_op.input();
      output_types.push_back(dequantized_result.getType());
      terminator->setOperand(i, dequantized_result);
      returned_op->erase();
    } else {
      output_types.push_back(returned_value.getType());
    }
  }
  auto new_func_type = builder.getFunctionType(input_types, output_types);
  func.setType(new_func_type);
}

enum RemoveVolatileOpsType {
  // Remove all volatile quant-dequant ops.
  kPreserveNone,
  // Preserve volatile quant-dequants for input and output ops.
  kPreserveInputsAndOutputs,
};

// Removes the back-to-back quantize and dequantize ops with the volatile
// attribute.
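// For example (sketch):
//   %q  = "tfl.quantize"(%0) {volatile} ...
//   %dq = "tfl.dequantize"(%q) ...
// is folded so that users of %dq use %0 directly. If %q is a requantize op
// (i.e. its input is already quantized), the dequantize op is instead moved
// before it.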
template <RemoveVolatileOpsType remove_volatile_ops_type>
struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
  explicit RemoveVolatileOps(MLIRContext* context)
      : OpRewritePattern<DequantizeOp>(context, 1) {}

  LogicalResult matchAndRewrite(DequantizeOp op,
                                PatternRewriter& rewriter) const override {
    auto input_op = op.input().getDefiningOp();
    if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
      if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();

      if (remove_volatile_ops_type == kPreserveInputsAndOutputs) {
        // Don't remove leading and trailing QDQ for the PQT workflow, so the
        // IO-modifying library can work correctly.
        if (!q.input().getDefiningOp()) return failure();
        if (op->hasOneUse() &&
            op->user_begin()->hasTrait<OpTrait::IsTerminator>())
          return failure();
      }
      // If the quantize op is a requantize op, it is being used in other scale
      // adjustments and should be kept. Instead, move the dequantize op before
      // the requantize op to remove the unnecessary requantize op.
      if (auto qtype = quant::QuantizedType::getQuantizedElementType(
              q.input().getType())) {
        rewriter.setInsertionPoint(op);
        rewriter.replaceOpWithNewOp<DequantizeOp>(op, op.output().getType(),
                                                  q.input());
        return success();
      }

      op.replaceAllUsesWith(q.input());
      return success();
    }
    return failure();
  }
};

// Removes operations with side effects (e.g. LSTM, SVDF) whose outputs are
// all unused (dangling).
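// For example, a TFL::LSTMOp whose results are all unused is erased by this
// pattern even though its side-effect trait would otherwise keep it alive
// through generic dead-code elimination.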
template <typename OpTy>
struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
 public:
  explicit PruneUnusedOpsWithSideEffect(MLIRContext* context)
      : OpRewritePattern<OpTy>(context) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter& rewriter) const override {
    if (op.getOperation()->template hasTrait<OpTrait::IsTerminator>()) {
      return failure();
    }
    for (auto result : op.getOperation()->getOpResults()) {
      if (!result.use_empty()) {
        return failure();
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"

void PostQuantizePass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto func = getFunction();
  auto* ctx = func.getContext();
  TFL::populateWithGenerated(patterns);
  patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
  patterns.insert<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
  patterns
      .insert<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
          ctx);
  patterns.insert<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  if (!emit_quant_adaptor_ops_) {
    RemoveQuantizationAdaptorOps(getFunction());
  }

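  // Phase 2: with the adaptor ops removed, fold trivial requantize ops again
  // and drop volatile QDQ pairs, preserving the ones still attached to the
  // function inputs and outputs.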
  OwningRewritePatternList phase_2_patterns(&getContext());
  TFL::populateWithGenerated(phase_2_patterns);
  phase_2_patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>,
                          RemoveVolatileOps<kPreserveInputsAndOutputs>>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}

void PostQuantizeRemoveQDQPass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto func = getFunction();
  auto* ctx = func.getContext();
  TFL::populateWithGenerated(patterns);
  patterns.insert<RemoveVolatileOps<kPreserveNone>>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
    bool emit_quant_adaptor_ops) {
  return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
}

// Creates an instance of the TensorFlow Lite dialect PostQuantizeRemoveQDQ
// pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizeRemoveQDQPass() {
  return std::make_unique<PostQuantizeRemoveQDQPass>();
}

static PassRegistration<PostQuantizePass> pass;

static PassRegistration<PostQuantizeRemoveQDQPass> remove_qdq_pass;
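
// With these registrations, the passes can be invoked by their arguments from
// an mlir-opt-style driver (the exact tool name depends on the build; tf-opt
// is used here as an assumed example):
//   tf-opt model.mlir -tfl-post-quantize -tfl-post-quantize-remove-qdq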

}  // namespace TFL
}  // namespace mlir