/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the TFLite
// dialect.
#include <iterator>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/monitoring/counter.h"

// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_allowlist(
    "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
    llvm::cl::desc("comma-separated list of allowlisted functions to be "
                   "quantized. Only used in tests"),
    llvm::cl::CommaSeparated);

// NOLINTNEXTLINE
static llvm::cl::opt<bool> quantize_signed(
    "tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
    llvm::cl::desc("signed inference type. Only used in tests"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> post_training_quantize(
    "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
    llvm::cl::desc("enable post training quantization. Only used in tests"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> legacy_float_scale(
    "tfl-test-legacy-float-scale", llvm::cl::value_desc("bool"),
    llvm::cl::desc("calculate quantization scales in float instead of double"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether to disable per-channel quantized weights."),
    llvm::cl::init(false));

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
namespace mlir {
namespace TFL {

namespace {

auto* tflite_quantizer_usage_stats = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/lite/quantization/transforms/stats",
    "The number of quantization pass invocations.", "path");

// Applies prepare quantization on the model in TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// makes the quantization rules for some operations in quantization-aware
// training simpler.
class PrepareQuantizePass
    : public PassWrapper<PrepareQuantizePass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<TensorFlowLiteDialect, ::mlir::quant::QuantizationDialect>();
  }

 public:
  // Constructor used by the PassRegistration; enforces uint8 quantization.
  // This is only used by tests.
  explicit PrepareQuantizePass() {
    quant_specs_.inference_type =
        quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
    quant_specs_.post_training_quantization = post_training_quantize;
    quant_specs_.legacy_float_scale = legacy_float_scale;
  }

  // Constructor used when manually creating the pass.
  explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
      : quant_specs_(quant_specs) {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-prepare-quantize";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Prepare TFL dialect for quantization";
  }

  void runOnFunction() override;

 private:
  // Set the quantization parameters of the input nodes. These parameters are
  // converted from the user specified input value ranges. The input nodes with
  // non-float tensor types will be skipped because they are not quantizable.
  // Returns true if the number of input nodes doesn't equal the number of
  // input ranges.
  bool SetInputNodesQuantizationParams(FuncOp func);

  // The function might contain more stats ops than required, and conflicting
  // calibration stats will introduce requantize ops. This method tries to
  // remove all the redundant stats ops.
  bool RemoveRedundantStats(FuncOp func);

  // Verify that the quantization specification is expected for quantizing the
  // current function.
  bool IsLegalQuantSpecs(FuncOp func) {
    if (func.getName() == quant_specs_.target_func) {
      return func.getNumArguments() == quant_specs_.input_ranges.size();
    }
    return true;
  }

  // Get the min and max values from the quantization specification for the
  // current function and argument index. Uses default values if the function
  // is specified in the `quantize_allowlist`.
  std::pair<llvm::Optional<double>, llvm::Optional<double>>
  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
      return {0.0, 255.0};
    }
  }

  // Apply sanity checks and report warnings for models that don't follow best
  // quantization practices. This also fixes some simple violations.
  void SanityCheckAndAdjustment(FuncOp func);

  // Whether the func contains Quantize ops. This is used to determine whether
  // to use the quantization parameters from the fixed output range property.
  bool ContainsQuantizeOps(FuncOp func);

  QuantizationSpecs quant_specs_;
};

bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
  StringRef func_name = func.getName();
  auto& target_func = quant_specs_.target_func;

  // Skip this function because it isn't the target function from the spec and
  // isn't in the function allowlist.
  if (target_func != func_name &&
      !llvm::is_contained(quantize_allowlist, func_name)) {
    return false;
  }

  // If the validation fails, the pass should stop immediately.
  if (!IsLegalQuantSpecs(func)) {
    return true;
  }

  OpBuilder builder(func);
  bool is_signed = quant_specs_.IsSignedInferenceType();
  IntegerAttr num_bits =
      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
  BoolAttr narrow_range = builder.getBoolAttr(false);

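  // Insert a quantize/dequantize (QDQ) pair in front of a float argument,
  // using the user-specified min/max range of that argument, so later passes
  // can pick up the quantization parameters of the input.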
  auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
                             Block::iterator insertion_point, Value arg,
                             int i) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        // If there are existing quantize ops, they are from training and we
        // should respect them.
        if (arg.hasOneUse() &&
            llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
          return;
        }

        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        // If the input min/max or mean/std are not specified, skip.
        if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;

        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type,
            builder.getF64FloatAttr(min_max.first.getValue()),
            builder.getF64FloatAttr(min_max.second.getValue()),
            /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op =
            builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
        auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
                                                             q_op.getResult());
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
  };

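  // Insert quantize/dequantize ops for each quantizable float function
  // argument.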
  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

  return false;
}

#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"

bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
  return RemoveRedundantStatsOps(func, GetOpQuantSpec);
}

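// Returns the result of the DequantizeCastOp if `user` is a QuantizeCastOp
// whose result is consumed by a DequantizeCastOp (i.e. the value is already
// covered by a QDQ pair); otherwise returns a null Value.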
static Value Quantized(Operation* user) {
  if (auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user)) {
    if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>(
            *q.getResult().user_begin())) {
      return dq.getResult();
    }
  }
  return {};
}

void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
  // If an op output has two users, one of them being a quantize op and the
  // other being returned directly, return the quantized result instead so this
  // op can be quantized. This is only applied to the returned result because
  // the error will not accumulate there.

  func.walk([&](ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.operands()) {
      llvm::SmallVector<Value, 4> quantized;
      for (auto user : returned.getUsers()) {
        if (auto q = Quantized(user)) {
          quantized.push_back(q);
        }
      }
      if (quantized.size() == 1) {
        ret.setOperand(i, quantized.front());
      }
      i++;
    }
  });

  // We prefer placing quantization emulation ops on the results of the concat
  // ops.
  func.walk([&](ConcatenationOp concat) {
    if (concat.output().hasOneUse() &&
        Quantized(*concat.output().user_begin())) {
      return;
    }
    concat.emitWarning(
        "Missing quantization parameter on the output might introduce "
        "quantization error!");
  });

  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
  // eliminated at this point. This only occurs for the pattern
  //   (Quant (Dequant (Quant $in, $qB)), $qA)   with $qB != $qA,
  // where the qdq pair denotes a non-trivial requantization of an already
  // quantized value. Since this makes little sense (directly quantizing
  // (Quant $in, $qA) would introduce less quantization noise), the likely
  // cause is a minor error in constructing the original network model that
  // introduced back-to-back fake quantization operations. Hence: emit a
  // warning. N.B. at this point we are (temporarily) in the quantization
  // dialect (presumably to enable reuse in XLA etc.), hence the
  // quant::*QuantizeCastOp ops we are matching here.
  //
  func.walk([&](quant::QuantizeCastOp q_op) {
    // Check whether this quantize op consumes the result of a dequantize op.
    auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
        q_op.getOperand().getDefiningOp());
    if (!dq_op) {
      return;
    }
    auto dq_arg = dq_op.getOperand();

    if (!dq_arg.hasOneUse()) {
      // The initial quantization is used someplace else ... so it might be
      // reasonable for it to be requantized for another purpose.
      // Ideally we would still check whether the requantization narrows rather
      // than widens the representation.
      return;
    }

    // Invariant:
    //   isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
    //   dq_arg.getType() != q_op.getResult().getType()
    //
    // as otherwise the qdq pair would have been optimized away.
    auto qd_arg_def_q_op =
        dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
    if (!qd_arg_def_q_op) {
      return;
    }

    qd_arg_def_q_op.emitWarning()
        << " quantizer's output has another quantizer (" << q_op.getLoc()
        << ") as consumer - intentional?";
  });
}

bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
  for (const auto& op : func.getOps()) {
    if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
  }
  return false;
}

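// Converts the imported calibration statistics (quant stats ops) into
// QuantizeCast/DequantizeCast pairs.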
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;

void PrepareQuantizePass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext* ctx = func.getContext();
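  // Work on ops in the quant dialect; they are converted back to TFL quantize
  // ops at the end of this function.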
  ConvertTFLQuantOpsToMlirQuantOps(func);

  if (quant_specs_.post_training_quantization) {
    tflite_quantizer_usage_stats->GetCell("post_training")->IncrementBy(1);
    RemoveRedundantStats(func);
  } else {
    tflite_quantizer_usage_stats->GetCell("during_training")->IncrementBy(1);
    // Set the quantization parameters for the quantizable input nodes. If this
    // fails, return from the function immediately. This is only required for
    // quantization-aware training model conversion.
    if (SetInputNodesQuantizationParams(func)) {
      return;
    }
  }

  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  // When this is true, the quantizer will try its best to extract the
  // quantization parameters from the op quantization property and constant
  // content. This is also set to true when either the `quantize_allowlist` or
  // `quantize_signed` test flag is enabled.
  bool eager_quantize = ContainsQuantizeOps(func) ||
                        (!quantize_allowlist.empty() || quantize_signed);
  // Infer the tensor range for the activation ops and weight constants unless
  // it is explicitly disabled.
  bool infer_tensor_range =
      (quant_specs_.post_training_quantization || eager_quantize) &&
      !quant_specs_.disable_infer_tensor_range;

  // LSTM's restrict_scale requirement should be handled before converting
  // stats to Q-DQ ops. The rewrite is also applied in the non-PTQ case (with
  // an empty pattern set) to keep the op ordering consistent; otherwise some
  // FileCheck tests would fail.
  OwningRewritePatternList patterns_1(&getContext());
  if (quant_specs_.post_training_quantization) {
    patterns_1.insert<PrepareLstmOutputScale<LSTMOp>>(ctx);
    patterns_1.insert<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(
        ctx);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));

  // During the legalization, unsigned quantized types are used, so we have to
  // convert them all to signed.
  OwningRewritePatternList patterns_2(&getContext());
  if (is_signed) {
    patterns_2.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(
        ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.insert<PrepareQuantStats>(bit_width, false, true,
                                         quant_specs_.legacy_float_scale, ctx);
  } else {
    // Convert quant stats to uint8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.insert<PrepareQuantStats>(bit_width, false, false,
                                         quant_specs_.legacy_float_scale, ctx);
  }

  if (quant_specs_.post_training_quantization) {
    patterns_2.insert<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
    patterns_2.insert<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
        ctx, quant_specs_);
    patterns_2.insert<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));

  SanityCheckAndAdjustment(func);

  // Finally, the quantization parameters can be propagated to the rest of the
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
      GetOpQuantSpec, infer_tensor_range, quant_specs_.legacy_float_scale);

  ConvertMlirQuantOpsToTFLQuantOps(func);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
    const QuantizationSpecs& quant_specs) {
  return std::make_unique<PrepareQuantizePass>(quant_specs);
}

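// Registers the pass so it can be referenced from textual pass pipelines and
// the command line via its "tfl-prepare-quantize" argument.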
static PassRegistration<PrepareQuantizePass> pass;

}  // namespace TFL
}  // namespace mlir