• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass applies quantization propagation on TFLite dialect.
17 #include <iterator>
18 #include <string>
19 
20 #include "absl/memory/memory.h"
21 #include "llvm/ADT/Optional.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/MathExtras.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
30 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
31 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
41 #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
42 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
43 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
44 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
45 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
46 #include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
47 #include "tensorflow/core/framework/types.pb.h"
48 #include "tensorflow/core/lib/monitoring/counter.h"
49 
50 // NOLINTNEXTLINE
51 static llvm::cl::list<std::string> quantize_allowlist(
52     "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
53     llvm::cl::desc("comma separated list of allowlisted functions to be "
54                    "quantized. Only used in tests"),
55     llvm::cl::CommaSeparated);
56 
57 // NOLINTNEXTLINE
58 static llvm::cl::opt<bool> quantize_signed(
59     "tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
60     llvm::cl::desc("signed inference type. Only used in tests"),
61     llvm::cl::init(false));
62 
63 // NOLINTNEXTLINE
64 static llvm::cl::opt<bool> post_training_quantize(
65     "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
66     llvm::cl::desc("enable post training quantization. Only used in tests"),
67     llvm::cl::init(false));
68 
69 // NOLINTNEXTLINE
70 static llvm::cl::opt<bool> legacy_float_scale(
71     "tfl-test-legacy-float-scale", llvm::cl::value_desc("bool"),
72     llvm::cl::desc("calculate quantization scales in float instead of double"),
73     llvm::cl::init(false));
74 
75 // NOLINTNEXTLINE
76 static llvm::cl::opt<bool> disable_per_channel(
77     "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
78     llvm::cl::desc("Whether disable per-channel quantized weights."),
79     llvm::cl::init(false));
80 
81 //===----------------------------------------------------------------------===//
82 // The prepare-quantize Pass.
83 //
84 namespace mlir {
85 namespace TFL {
86 
87 namespace {
88 
89 auto* tflite_quantizer_usage_stats = tensorflow::monitoring::Counter<1>::New(
90     "/tensorflow/lite/quantization/transforms/stats",
91     "The number of quantization pass invocations.", "path");
92 
93 // Applies prepare quantization on the model in TFL dialect. This pass runs
94 // before the quantization pass and propagate the quantization parameters
95 // across ops. This step is necessary for post-training quantization and also
96 // making the quantization rule for some operations in the quantization-aware
97 // training quantization simpler.
98 class PrepareQuantizePass
99     : public PassWrapper<PrepareQuantizePass, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const100   void getDependentDialects(DialectRegistry& registry) const override {
101     registry
102         .insert<TensorFlowLiteDialect, ::mlir::quant::QuantizationDialect>();
103   }
104 
105  public:
106   // Constructor used by the PassRegistration and enforce uint8 quantization.
107   // This is only used by test.
PrepareQuantizePass()108   explicit PrepareQuantizePass() {
109     quant_specs_.inference_type =
110         quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
111     quant_specs_.post_training_quantization = post_training_quantize;
112     quant_specs_.legacy_float_scale = legacy_float_scale;
113   }
114 
115   // Constructor used by manually creating the pass.
PrepareQuantizePass(const QuantizationSpecs & quant_specs)116   explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
117       : quant_specs_(quant_specs) {}
118 
119   void runOnFunction() override;
120 
121  private:
122   // Set the quantization parameters of the input nodes. These parameters are
123   // converted from the user specified input value ranges. The input nodes with
124   // non-float tensor types will be skipped because they are not quantizable.
125   // Return true if number of input nodes doesn't equal to that of the input
126   // ranges.
127   bool SetInputNodesQuantizationParams(FuncOp func);
128 
129   // The function might contain more stats ops than required, and it will
130   // introduce requantize if the calibration stats have conflicts. This method
131   // tries to remove all the redundant stats ops.
132   bool RemoveRedundantStats(FuncOp func);
133 
134   // Verify the quantization specification is expected for quantizing the
135   // current function.
IsLegalQuantSpecs(FuncOp func)136   bool IsLegalQuantSpecs(FuncOp func) {
137     if (func.getName() == quant_specs_.target_func) {
138       return func.getNumArguments() == quant_specs_.input_ranges.size();
139     }
140     return true;
141   }
142 
143   // Get the min and max values from the quantization specification for the
144   // current function and argument index. Uses default values if the function
145   // is specified in the `quantize_allowlist`.
146   std::pair<llvm::Optional<double>, llvm::Optional<double>>
GetMinMaxValuesForArgument(llvm::StringRef func_name,int index)147   GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
148     if (func_name == quant_specs_.target_func) {
149       return quant_specs_.input_ranges[index];
150     } else {
151       return {0.0, 255.0};
152     }
153   }
154 
155   // Apply some sanity check and report some warnings for those who don't follow
156   // the best quantization practice. This also fixes some simple violations.
157   void SanityCheckAndAdjustment(FuncOp func);
158 
159   // Whether the func contains Quantize ops. This is used to determine whether
160   // to use the quantization parameters from the fixed output range property.
161   bool ContainsQuantizeOps(FuncOp func);
162 
163   QuantizationSpecs quant_specs_;
164 };
165 
SetInputNodesQuantizationParams(FuncOp func)166 bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
167   StringRef func_name = func.getName();
168   auto& target_func = quant_specs_.target_func;
169 
170   // Skip this function because it isn't the target function from the spec or
171   // in the function while list.
172   if (target_func != func_name &&
173       !llvm::is_contained(quantize_allowlist, func_name)) {
174     return false;
175   }
176 
177   // If the validation fails, the pass should stop immediately.
178   if (!IsLegalQuantSpecs(func)) {
179     return true;
180   }
181 
182   OpBuilder builder(func);
183   bool is_signed = quant_specs_.IsSignedInferenceType();
184   IntegerAttr num_bits =
185       builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
186   BoolAttr narrow_range = builder.getBoolAttr(false);
187 
188   auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
189                              Block::iterator insertion_point, Value arg,
190                              int i) {
191     if (auto shaped = input_type.dyn_cast<ShapedType>()) {
192       if (shaped.getElementType().isa<FloatType>()) {
193         // If there are existing quantize ops, they are from training and we
194         // should respect them.
195         if (arg.hasOneUse() &&
196             llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
197           return;
198         }
199 
200         auto min_max = GetMinMaxValuesForArgument(func_name, i);
201         // The input min/max or mean/std are not specified, then skip.
202         if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;
203 
204         TypeAttr params = quant::GetQuantizedTypeAttr(
205             builder, input_type,
206             builder.getF64FloatAttr(min_max.first.getValue()),
207             builder.getF64FloatAttr(min_max.second.getValue()),
208             /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
209         builder.setInsertionPoint(block, insertion_point);
210         auto q_op =
211             builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
212         auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
213                                                              q_op.getResult());
214         arg.replaceAllUsesWith(dq_op.getResult());
215         q_op.setOperand(arg);
216       }
217     }
218   };
219 
220   for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
221     BlockArgument arg = func.getArgument(i);
222     auto* arg_block = arg.getOwner();
223     add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
224                     std::next(arg_block->begin(), i), arg, i);
225   }
226 
227   return false;
228 }
229 
230 #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
231 
RemoveRedundantStats(FuncOp func)232 bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
233   return RemoveRedundantStatsOps(func, GetOpQuantSpec);
234 }
235 
Quantized(Operation * user)236 static Value Quantized(Operation* user) {
237   if (auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user)) {
238     if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>(
239             *q.getResult().user_begin())) {
240       return dq.getResult();
241     }
242   }
243   return {};
244 }
245 
SanityCheckAndAdjustment(FuncOp func)246 void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
247   // If an op output has two users: one of them is a quantize op and another
248   // one is returned directly, we decide to return the quantized result instead,
249   // so this op can be quantized. This is only applied on the returned result
250   // because the error will not be accumulated.
251 
252   func.walk([&](ReturnOp ret) {
253     int i = 0;
254     for (Value returned : ret.operands()) {
255       llvm::SmallVector<Value, 4> quantized;
256       for (auto user : returned.getUsers()) {
257         if (auto q = Quantized(user)) {
258           quantized.push_back(q);
259         }
260       }
261       if (quantized.size() == 1) {
262         ret.setOperand(i, quantized.front());
263       }
264       i++;
265     }
266   });
267 
268   // We prefer to placing quantization emulation ops on the results of the
269   // concat ops.
270   func.walk([&](ConcatenationOp concat) {
271     if (concat.output().hasOneUse() &&
272         Quantized(*concat.output().user_begin())) {
273       return;
274     }
275     concat.emitWarning(
276         "Missing quantization parameter on the output might introduce "
277         "quantization error!");
278   });
279 
280   // Check for  (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
281   // eliminated at this point.  This only occurs for the pattern
282   //      (Quant (Dequant (Quant $in, $qB)), $qA)   $qB != $qA
283   // where the  qdq pair denotes a non-trivial requantization of an
284   // already quantized value. Since this makes little sense (directly quantizing
285   // (Quant $in, $qA) would introduce less quantization noise) the likely cause
286   // is an minor error in constructing the original network model that
287   // introduced back-to-back Fake Quantization operations. Hence: emit a
288   // warning. N.b. at this point we're (teporarility) in the quantization
289   // dialect (presumably enable re-use in xla etc) quant::*QuantizeCastOp
290   // we're matching here.
291   //
292   func.walk([&](quant::QuantizeCastOp q_op) {
293     // If up with end up with
294     auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
295         q_op.getOperand().getDefiningOp());
296     if (!dq_op) {
297       return;
298     }
299     auto dq_arg = dq_op.getOperand();
300 
301     if (!dq_arg.hasOneUse()) {
302       // The initial quantization is used someplace else ... so it might be
303       // reasonable for it to requantized for another purpose.
304       // Ideally would want to still check whether requantization narrows
305       // rather than widens the representation.
306       return;
307     }
308 
309     // Invariant:
310     // isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
311     // getdq_arg.getType() != q_op.getResult().getType()
312     //
313     // as otherwise qdq pair would have been optimized away.
314     auto qd_arg_def_q_op =
315         dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
316     if (!qd_arg_def_q_op) {
317       return;
318     }
319 
320     qd_arg_def_q_op.emitWarning()
321         << " quantizer's output has another quantizer (" << q_op.getLoc()
322         << ") as consumer - intentional?";
323   });
324 }
325 
ContainsQuantizeOps(FuncOp func)326 bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
327   for (const auto& op : func.getOps()) {
328     if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
329   }
330   return false;
331 }
332 
333 using PrepareQuantStats =
334     quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
335 
runOnFunction()336 void PrepareQuantizePass::runOnFunction() {
337   FuncOp func = getFunction();
338   MLIRContext* ctx = func.getContext();
339   ConvertTFLQuantOpsToMlirQuantOps(func);
340 
341   if (quant_specs_.post_training_quantization) {
342     tflite_quantizer_usage_stats->GetCell("post_training")->IncrementBy(1);
343     RemoveRedundantStats(func);
344   } else {
345     tflite_quantizer_usage_stats->GetCell("during_training")->IncrementBy(1);
346     // Set the quantization parameters for the quantizable input nodes. If this
347     // failed, return the function immediately. This is only required for
348     // quantization aware training model conversion.
349     if (SetInputNodesQuantizationParams(func)) {
350       return;
351     }
352   }
353 
354   bool is_signed = quant_specs_.IsSignedInferenceType();
355   int bit_width = quant_specs_.GetQuantizationTypeWidth();
356   // When this is true, the quantizer will try its best to extract the
357   // quantization parameters from the op quantization property and constant
358   // content. This is also set to true when the `quantize_allowlist` and
359   // `quantize_signed` test flags are enabled.
360   bool eager_quantize = ContainsQuantizeOps(func) ||
361                         (!quantize_allowlist.empty() || quantize_signed);
362   // Infer the tensor range for the activation ops and weight constants unless
363   // it is disabled explicitly.
364   bool infer_tensor_range =
365       (quant_specs_.post_training_quantization || eager_quantize) &&
366       !quant_specs_.disable_infer_tensor_range;
367 
368   // LSTM's restrict_scale requirement should be handled before converting stats
369   // to Q-DQ ops. The pattern is applied for non-PTQ case to make op ordering
370   // consistent. Otherwise some FileCheck tests would fail.
371   OwningRewritePatternList patterns_1;
372   if (quant_specs_.post_training_quantization) {
373     patterns_1.insert<PrepareLstmOutputScale<LSTMOp>>(ctx);
374     patterns_1.insert<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(
375         ctx);
376   }
377   (void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));
378 
379   // During the legalization, unsigned quantized type is used, so we have to
380   // convert all of them to signed.
381   OwningRewritePatternList patterns_2;
382   if (is_signed) {
383     patterns_2.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(
384         ctx);
385     // Convert quant stats to int8 quantization parameters.
386     // Currently, only activation stats are imported, so narrow_range = false.
387     patterns_2.insert<PrepareQuantStats>(bit_width, false, true,
388                                          quant_specs_.legacy_float_scale, ctx);
389   } else {
390     // Convert quant stats to uint8 quantization parameters.
391     // Currently, only activation stats are imported, so narrow_range = false.
392     patterns_2.insert<PrepareQuantStats>(bit_width, false, false,
393                                          quant_specs_.legacy_float_scale, ctx);
394   }
395 
396   if (quant_specs_.post_training_quantization) {
397     patterns_2.insert<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
398     patterns_2.insert<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
399         ctx, quant_specs_);
400     patterns_2.insert<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
401   }
402   (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));
403 
404   SanityCheckAndAdjustment(func);
405 
406   // Finally, the quantization parameters can be propagated to the rest of the
407   // values (tensors).
408   ApplyQuantizationParamsPropagation(
409       func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
410       GetOpQuantSpec, infer_tensor_range, quant_specs_.legacy_float_scale);
411 
412   ConvertMlirQuantOpsToTFLQuantOps(func);
413 }
414 
415 }  // namespace
416 
417 // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
CreatePrepareQuantizePass(const QuantizationSpecs & quant_specs)418 std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
419     const QuantizationSpecs& quant_specs) {
420   return std::make_unique<PrepareQuantizePass>(quant_specs);
421 }
422 
423 static PassRegistration<PrepareQuantizePass> pass(
424     "tfl-prepare-quantize", "Prepare TFL dialect for quantization");
425 
426 }  // namespace TFL
427 }  // namespace mlir
428