/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the TFLite
// dialect.
#include <iterator>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/monitoring/counter.h"

// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_allowlist(
    "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
    llvm::cl::desc("comma separated list of allowlisted functions to be "
                   "quantized. Only used in tests"),
    llvm::cl::CommaSeparated);

// NOLINTNEXTLINE
static llvm::cl::opt<bool> quantize_signed(
    "tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
    llvm::cl::desc("signed inference type. Only used in tests"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> post_training_quantize(
    "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
    llvm::cl::desc("enable post training quantization. Only used in tests"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> legacy_float_scale(
    "tfl-test-legacy-float-scale", llvm::cl::value_desc("bool"),
    llvm::cl::desc("calculate quantization scales in float instead of double"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether to disable per-channel quantized weights."),
    llvm::cl::init(false));
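
// Illustrative usage of these test flags (an assumption, based on how the
// TFLite FileCheck tests typically drive passes through the MLIR opt tooling):
//   tf-opt -tfl-prepare-quantize -tfl-test-quantize-signed \
//       -tfl-test-post-training-quantize model.mlir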

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
namespace mlir {
namespace TFL {

namespace {

auto* tflite_quantizer_usage_stats = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/lite/quantization/transforms/stats",
    "The number of quantization pass invocations.", "path");

// Applies prepare quantization on the model in the TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// simplifies the quantization rules for some operations in quantization-aware
// training.
class PrepareQuantizePass
    : public PassWrapper<PrepareQuantizePass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<TensorFlowLiteDialect, ::mlir::quant::QuantizationDialect>();
  }

 public:
  // Constructor used by the PassRegistration; enforces uint8 quantization.
  // This is only used by tests.
  explicit PrepareQuantizePass() {
    quant_specs_.inference_type =
        quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
    quant_specs_.post_training_quantization = post_training_quantize;
    quant_specs_.legacy_float_scale = legacy_float_scale;
  }

  // Constructor used when manually creating the pass.
  explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
      : quant_specs_(quant_specs) {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the command line, for example).
    return "tfl-prepare-quantize";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Prepare TFL dialect for quantization";
  }

  void runOnFunction() override;

 private:
  // Sets the quantization parameters of the input nodes. These parameters are
  // converted from the user-specified input value ranges. Input nodes with
  // non-float tensor types will be skipped because they are not quantizable.
  // Returns true if the number of input nodes doesn't equal the number of
  // input ranges.
  bool SetInputNodesQuantizationParams(FuncOp func);

  // The function might contain more stats ops than required, and conflicting
  // calibration stats will introduce requantize ops. This method tries to
  // remove all the redundant stats ops.
  bool RemoveRedundantStats(FuncOp func);

  // Verifies that the quantization specification is as expected for quantizing
  // the current function.
  bool IsLegalQuantSpecs(FuncOp func) {
    if (func.getName() == quant_specs_.target_func) {
      return func.getNumArguments() == quant_specs_.input_ranges.size();
    }
    return true;
  }

  // Gets the min and max values from the quantization specification for the
  // current function and argument index. Uses default values if the function
  // is specified in the `quantize_allowlist`.
  std::pair<llvm::Optional<double>, llvm::Optional<double>>
  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
      return {0.0, 255.0};
    }
  }

  // Applies some sanity checks and reports warnings for models that don't
  // follow best quantization practice. This also fixes some simple violations.
  void SanityCheckAndAdjustment(FuncOp func);

  // Whether the func contains Quantize ops. This is used to determine whether
  // to use the quantization parameters from the fixed output range property.
  bool ContainsQuantizeOps(FuncOp func);

  QuantizationSpecs quant_specs_;
};

bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
  StringRef func_name = func.getName();
  auto& target_func = quant_specs_.target_func;

  // Skip this function because it isn't the target function from the spec or
  // in the function allowlist.
  if (target_func != func_name &&
      !llvm::is_contained(quantize_allowlist, func_name)) {
    return false;
  }

  // If the validation fails, the pass should stop immediately.
  if (!IsLegalQuantSpecs(func)) {
    return true;
  }

  OpBuilder builder(func);
  bool is_signed = quant_specs_.IsSignedInferenceType();
  IntegerAttr num_bits =
      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
  BoolAttr narrow_range = builder.getBoolAttr(false);

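  // Helper that wraps a float function argument in a quant::QuantizeCastOp /
  // quant::DequantizeCastOp pair built from the user-provided min/max range,
  // so downstream passes see explicit quantization parameters for the input.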
  auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
                             Block::iterator insertion_point, Value arg,
                             int i) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        // If there are existing quantize ops, they are from training and we
        // should respect them.
        if (arg.hasOneUse() &&
            llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
          return;
        }

        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        // If the input min/max or mean/std are not specified, skip this
        // argument.
        if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;

        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type,
            builder.getF64FloatAttr(min_max.first.getValue()),
            builder.getF64FloatAttr(min_max.second.getValue()),
            /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op =
            builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
        auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
                                                             q_op.getResult());
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
  };

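  // Attempt to insert the quantize/dequantize pair for each function argument;
  // the helper above skips non-float arguments and arguments that already feed
  // a quantize op.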
  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

  return false;
}

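// The generated .inc below provides the auto-generated per-op quantization
// spec getter (GetOpQuantSpec) used by RemoveRedundantStats and by
// ApplyQuantizationParamsPropagation further down.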
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"

bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
  return RemoveRedundantStatsOps(func, GetOpQuantSpec);
}

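// Returns the dequantized result if `user` is a quantize op whose (first) user
// is a dequantize op, i.e. the value produced by an existing
// quantize-dequantize pair; otherwise returns a null Value.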
static Value Quantized(Operation* user) {
  if (auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user)) {
    if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>(
            *q.getResult().user_begin())) {
      return dq.getResult();
    }
  }
  return {};
}

void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
  // If an op's output has two users, one of which is a quantize op while the
  // other is returned directly, return the quantized result instead so the op
  // can be quantized. This is only applied to returned results because the
  // quantization error will not be accumulated further there.
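  //
  // Schematically (illustrative pseudo-IR, not exact syntax), the rewrite
  // turns
  //   %q  = quantize(%out)
  //   %dq = dequantize(%q)
  //   return %out
  // into
  //   return %dq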

  func.walk([&](ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.operands()) {
      llvm::SmallVector<Value, 4> quantized;
      for (auto user : returned.getUsers()) {
        if (auto q = Quantized(user)) {
          quantized.push_back(q);
        }
      }
      if (quantized.size() == 1) {
        ret.setOperand(i, quantized.front());
      }
      i++;
    }
  });

  // We prefer placing quantization emulation ops on the results of concat ops.
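  // The walk below only checks whether the concat output already feeds an
  // existing quantize/dequantize pair; if not, it warns that missing output
  // quantization parameters may introduce quantization error.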
  func.walk([&](ConcatenationOp concat) {
    if (concat.output().hasOneUse() &&
        Quantized(*concat.output().user_begin())) {
      return;
    }
    concat.emitWarning(
        "Missing quantization parameter on the output might introduce "
        "quantization error!");
  });

  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
  // eliminated at this point. This only occurs for the pattern
  //      (Quant (Dequant (Quant $in, $qB)), $qA)   $qB != $qA
  // where the qdq pair denotes a non-trivial requantization of an already
  // quantized value. Since this makes little sense (directly quantizing
  // (Quant $in, $qA) would introduce less quantization noise), the likely
  // cause is a minor error in constructing the original network model that
  // introduced back-to-back fake quantization operations. Hence: emit a
  // warning. N.B. at this point we are (temporarily) in the quantization
  // dialect (presumably to enable reuse in XLA etc.), so it is
  // quant::*QuantizeCastOp that we are matching here.
  //
  func.walk([&](quant::QuantizeCastOp q_op) {
    // Only look at quantize ops whose operand is produced by a dequantize op.
    auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
        q_op.getOperand().getDefiningOp());
    if (!dq_op) {
      return;
    }
    auto dq_arg = dq_op.getOperand();

    if (!dq_arg.hasOneUse()) {
      // The initial quantization is used somewhere else ... so it might be
      // reasonable for it to be requantized for another purpose.
      // Ideally we would still check whether the requantization narrows
      // rather than widens the representation.
      return;
    }

    // Invariant:
    //   isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
    //   dq_arg.getType() != q_op.getResult().getType()
    //
    // as otherwise the qdq pair would have been optimized away.
    auto qd_arg_def_q_op =
        dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
    if (!qd_arg_def_q_op) {
      return;
    }

    qd_arg_def_q_op.emitWarning()
        << " quantizer's output has another quantizer (" << q_op.getLoc()
        << ") as consumer - intentional?";
  });
}

bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
  for (const auto& op : func.getOps()) {
    if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
  }
  return false;
}

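// ConvertStatsToQDQs rewrites the imported calibration statistics ops into
// quantize/dequantize cast pairs; here it is instantiated with the quant
// dialect cast ops.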
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;

void PrepareQuantizePass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext* ctx = func.getContext();
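  // Translate TFL quantize/dequantize ops into the generic quant dialect ops
  // so the shared quantization utilities below can operate on them; the
  // inverse conversion happens at the end of this function.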
  ConvertTFLQuantOpsToMlirQuantOps(func);

  if (quant_specs_.post_training_quantization) {
    tflite_quantizer_usage_stats->GetCell("post_training")->IncrementBy(1);
    RemoveRedundantStats(func);
  } else {
    tflite_quantizer_usage_stats->GetCell("during_training")->IncrementBy(1);
    // Set the quantization parameters for the quantizable input nodes. If this
    // fails, return immediately. This is only required for quantization-aware
    // training model conversion.
    if (SetInputNodesQuantizationParams(func)) {
      return;
    }
  }

  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  // When this is true, the quantizer will try its best to extract the
  // quantization parameters from the op quantization property and constant
  // content. This is also set to true when the `quantize_allowlist` or
  // `quantize_signed` test flags are enabled.
  bool eager_quantize = ContainsQuantizeOps(func) ||
                        (!quantize_allowlist.empty() || quantize_signed);
  // Infer the tensor range for the activation ops and weight constants unless
  // it is disabled explicitly.
  bool infer_tensor_range =
      (quant_specs_.post_training_quantization || eager_quantize) &&
      !quant_specs_.disable_infer_tensor_range;

  // LSTM's restrict_scale requirement should be handled before converting
  // stats to Q-DQ ops. The rewrite driver is also run in the non-PTQ case
  // (with no patterns) to keep the op ordering consistent; otherwise some
  // FileCheck tests would fail.
  OwningRewritePatternList patterns_1(&getContext());
  if (quant_specs_.post_training_quantization) {
    patterns_1.insert<PrepareLstmOutputScale<LSTMOp>>(ctx);
    patterns_1.insert<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(
        ctx);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));

  // During the legalization, unsigned quantized types are used, so we have to
  // convert them all to signed.
  OwningRewritePatternList patterns_2(&getContext());
  if (is_signed) {
    patterns_2.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(
        ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.insert<PrepareQuantStats>(bit_width, false, true,
                                         quant_specs_.legacy_float_scale, ctx);
  } else {
    // Convert quant stats to uint8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.insert<PrepareQuantStats>(bit_width, false, false,
                                         quant_specs_.legacy_float_scale, ctx);
  }

  if (quant_specs_.post_training_quantization) {
    patterns_2.insert<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
    patterns_2.insert<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
        ctx, quant_specs_);
    patterns_2.insert<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));

  SanityCheckAndAdjustment(func);

  // Finally, the quantization parameters can be propagated to the rest of the
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
      GetOpQuantSpec, infer_tensor_range, quant_specs_.legacy_float_scale);

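  // Convert the generic quant dialect ops created above back into TFL
  // quantize/dequantize ops for the rest of the TFLite pipeline.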
  ConvertMlirQuantOpsToTFLQuantOps(func);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
    const QuantizationSpecs& quant_specs) {
  return std::make_unique<PrepareQuantizePass>(quant_specs);
}

static PassRegistration<PrepareQuantizePass> pass;

}  // namespace TFL
}  // namespace mlir