/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the TFLite
// dialect.
#include <iterator>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/monitoring/counter.h"

// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_allowlist(
    "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
    llvm::cl::desc("comma separated list of allowlisted functions to be "
                   "quantized. Only used in tests"),
    llvm::cl::CommaSeparated);

// NOLINTNEXTLINE
static llvm::cl::opt<bool> quantize_signed(
    "tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
    llvm::cl::desc("signed inference type. Only used in tests"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> post_training_quantize(
    "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
    llvm::cl::desc("enable post training quantization. Only used in tests"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> legacy_float_scale(
    "tfl-test-legacy-float-scale", llvm::cl::value_desc("bool"),
    llvm::cl::desc("calculate quantization scales in float instead of double"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether to disable per-channel quantized weights."),
    llvm::cl::init(false));

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
namespace mlir {
namespace TFL {

namespace {

auto* tflite_quantizer_usage_stats = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/lite/quantization/transforms/stats",
    "The number of quantization pass invocations.", "path");

// Applies prepare-quantization on models in the TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// simplifies the quantization rules for some operations in quantization-aware
// training.
class PrepareQuantizePass
    : public PassWrapper<PrepareQuantizePass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<TensorFlowLiteDialect, ::mlir::quant::QuantizationDialect>();
  }

 public:
  // Constructor used by the PassRegistration; the inference type defaults to
  // uint8 unless the signed test flag is set. This is only used by tests.
  explicit PrepareQuantizePass() {
    quant_specs_.inference_type =
        quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
    quant_specs_.post_training_quantization = post_training_quantize;
    quant_specs_.legacy_float_scale = legacy_float_scale;
  }

  // Constructor used when manually creating the pass.
  explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
      : quant_specs_(quant_specs) {}

  void runOnFunction() override;

 private:
  // Sets the quantization parameters of the input nodes. These parameters are
  // converted from the user-specified input value ranges. Input nodes with
  // non-float tensor types are skipped because they are not quantizable.
  // Returns true if the number of input nodes doesn't match the number of
  // input ranges.
  bool SetInputNodesQuantizationParams(FuncOp func);

  // The function might contain more stats ops than required, and it will
  // introduce requantize if the calibration stats have conflicts. This method
  // tries to remove all the redundant stats ops.
  bool RemoveRedundantStats(FuncOp func);

  // Verifies that the quantization specification is expected for quantizing
  // the current function.
  bool IsLegalQuantSpecs(FuncOp func) {
    if (func.getName() == quant_specs_.target_func) {
      return func.getNumArguments() == quant_specs_.input_ranges.size();
    }
    return true;
  }

  // Gets the min and max values from the quantization specification for the
  // current function and argument index. Uses the default range [0.0, 255.0]
  // if the function is specified via the `quantize_allowlist` rather than as
  // the target function.
  std::pair<llvm::Optional<double>, llvm::Optional<double>>
  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
      return {0.0, 255.0};
    }
  }

  // Applies some sanity checks and reports warnings for cases that don't
  // follow quantization best practices. This also fixes some simple
  // violations.
  void SanityCheckAndAdjustment(FuncOp func);

  // Whether the func contains Quantize ops. This is used to determine whether
  // to use the quantization parameters from the fixed output range property.
  bool ContainsQuantizeOps(FuncOp func);

  QuantizationSpecs quant_specs_;
};

bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
  StringRef func_name = func.getName();
  auto& target_func = quant_specs_.target_func;

  // Skip this function because it isn't the target function from the spec or
  // in the function allowlist.
  if (target_func != func_name &&
      !llvm::is_contained(quantize_allowlist, func_name)) {
    return false;
  }

  // If the validation fails, the pass should stop immediately.
  if (!IsLegalQuantSpecs(func)) {
    return true;
  }

  OpBuilder builder(func);
  bool is_signed = quant_specs_.IsSignedInferenceType();
  IntegerAttr num_bits =
      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
  BoolAttr narrow_range = builder.getBoolAttr(false);

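  // Inserts a quant.QuantizeCastOp/DequantizeCastOp pair for a float function
  // argument, using the min/max range specified for that argument, so the
  // argument is consumed through an explicit quantize-dequantize pair.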
  auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
                             Block::iterator insertion_point, Value arg,
                             int i) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        // If there are existing quantize ops, they are from training and we
        // should respect them.
        if (arg.hasOneUse() &&
            llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
          return;
        }

        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        // If the input min/max or mean/std are not specified, skip.
        if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;

        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type,
            builder.getF64FloatAttr(min_max.first.getValue()),
            builder.getF64FloatAttr(min_max.second.getValue()),
            /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op =
            builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
        auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
                                                             q_op.getResult());
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
  };

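  // Add the quantize/dequantize casts for each quantizable float argument of
  // the function.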
  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

  return false;
}

#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"

bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
  return RemoveRedundantStatsOps(func, GetOpQuantSpec);
}

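// Returns the result of the quant.DequantizeCastOp if `user` is a
// quant.QuantizeCastOp whose first user is a DequantizeCastOp; otherwise
// returns an empty Value.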
static Value Quantized(Operation* user) {
  if (auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user)) {
    if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>(
            *q.getResult().user_begin())) {
      return dq.getResult();
    }
  }
  return {};
}

void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
  // If an op output has two users, one of them being a quantize op and the
  // other being returned directly, return the quantized result instead so
  // this op can be quantized. This is only applied to the returned result
  // because the error will not be accumulated.

  func.walk([&](ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.operands()) {
      llvm::SmallVector<Value, 4> quantized;
      for (auto user : returned.getUsers()) {
        if (auto q = Quantized(user)) {
          quantized.push_back(q);
        }
      }
      if (quantized.size() == 1) {
        ret.setOperand(i, quantized.front());
      }
      i++;
    }
  });

  // We prefer placing quantization emulation ops on the results of the
  // concat ops.
  func.walk([&](ConcatenationOp concat) {
    if (concat.output().hasOneUse() &&
        Quantized(*concat.output().user_begin())) {
      return;
    }
    concat.emitWarning(
        "Missing quantization parameter on the output might introduce "
        "quantization error!");
  });

  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
  // eliminated at this point. This only occurs for the pattern
  //      (Quant (Dequant (Quant $in, $qB)), $qA)   with $qB != $qA
  // where the qdq pair denotes a non-trivial requantization of an already
  // quantized value. Since this makes little sense (directly quantizing
  // (Quant $in, $qA) would introduce less quantization noise), the likely
  // cause is a minor error in constructing the original network model that
  // introduced back-to-back Fake Quantization operations. Hence: emit a
  // warning. N.b. at this point we're (temporarily) in the quantization
  // dialect (presumably to enable re-use in XLA etc.), so it is
  // quant::*QuantizeCastOp ops we're matching here.
  //
  func.walk([&](quant::QuantizeCastOp q_op) {
    auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
        q_op.getOperand().getDefiningOp());
    if (!dq_op) {
      return;
    }
    auto dq_arg = dq_op.getOperand();

    if (!dq_arg.hasOneUse()) {
      // The initial quantization is used someplace else ... so it might be
      // reasonable for it to be requantized for another purpose.
      // Ideally we would still want to check whether the requantization
      // narrows rather than widens the representation.
      return;
    }

    // Invariant:
    // isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
    //     dq_arg.getType() != q_op.getResult().getType()
    //
    // as otherwise the qdq pair would have been optimized away.
    auto qd_arg_def_q_op =
        dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
    if (!qd_arg_def_q_op) {
      return;
    }

    qd_arg_def_q_op.emitWarning()
        << " quantizer's output has another quantizer (" << q_op.getLoc()
        << ") as consumer - intentional?";
  });
}

bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
  for (const auto& op : func.getOps()) {
    if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
  }
  return false;
}

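// Rewrite pattern that converts quant.stats ops into explicit
// quant.QuantizeCastOp/DequantizeCastOp pairs carrying the computed
// quantization parameters.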
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;

void PrepareQuantizePass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext* ctx = func.getContext();
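  // Temporarily switch from TFL quantization ops to the generic MLIR Quant
  // dialect ops so the shared quantization utilities below can be reused; the
  // inverse conversion is applied at the end of this function.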
  ConvertTFLQuantOpsToMlirQuantOps(func);

  if (quant_specs_.post_training_quantization) {
    tflite_quantizer_usage_stats->GetCell("post_training")->IncrementBy(1);
    RemoveRedundantStats(func);
  } else {
    tflite_quantizer_usage_stats->GetCell("during_training")->IncrementBy(1);
    // Set the quantization parameters for the quantizable input nodes. If this
    // fails, return from the function immediately. This is only required for
    // quantization-aware training model conversion.
    if (SetInputNodesQuantizationParams(func)) {
      return;
    }
  }

  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  // When this is true, the quantizer will try its best to extract the
  // quantization parameters from the op quantization property and constant
  // content. This is also set to true when the `quantize_allowlist` and
  // `quantize_signed` test flags are enabled.
  bool eager_quantize = ContainsQuantizeOps(func) ||
                        (!quantize_allowlist.empty() || quantize_signed);
  // Infer the tensor range for the activation ops and weight constants unless
  // it is disabled explicitly.
  bool infer_tensor_range =
      (quant_specs_.post_training_quantization || eager_quantize) &&
      !quant_specs_.disable_infer_tensor_range;

  // LSTM's restrict_scale requirement should be handled before converting
  // stats to Q-DQ ops. The pattern driver is also run in the non-PTQ case to
  // keep op ordering consistent; otherwise some FileCheck tests would fail.
  OwningRewritePatternList patterns_1;
  if (quant_specs_.post_training_quantization) {
    patterns_1.insert<PrepareLstmOutputScale<LSTMOp>>(ctx);
    patterns_1.insert<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(
        ctx);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));

  // During legalization, the unsigned quantized type is used, so here we have
  // to convert all of them to signed.
  OwningRewritePatternList patterns_2;
  if (is_signed) {
    patterns_2.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(
        ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.insert<PrepareQuantStats>(bit_width, false, true,
                                         quant_specs_.legacy_float_scale, ctx);
  } else {
    // Convert quant stats to uint8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.insert<PrepareQuantStats>(bit_width, false, false,
                                         quant_specs_.legacy_float_scale, ctx);
  }

  if (quant_specs_.post_training_quantization) {
    patterns_2.insert<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
    patterns_2.insert<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
        ctx, quant_specs_);
    patterns_2.insert<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));

  SanityCheckAndAdjustment(func);

  // Finally, the quantization parameters can be propagated to the rest of the
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
      GetOpQuantSpec, infer_tensor_range, quant_specs_.legacy_float_scale);

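  // Convert the MLIR Quant dialect ops back to TFL quantization ops so the
  // rest of the TFLite pipeline can consume the result.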
  ConvertMlirQuantOpsToTFLQuantOps(func);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
    const QuantizationSpecs& quant_specs) {
  return std::make_unique<PrepareQuantizePass>(quant_specs);
}

static PassRegistration<PrepareQuantizePass> pass(
    "tfl-prepare-quantize", "Prepare TFL dialect for quantization");

}  // namespace TFL
}  // namespace mlir