1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass applies quantization on TFLite dialect.
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/Casting.h"
21 #include "llvm/Support/Debug.h"
22 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
28 #include "mlir/IR/MLIRContext.h" // from @llvm-project
29 #include "mlir/IR/Matchers.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/OperationSupport.h" // from @llvm-project
32 #include "mlir/IR/PatternMatch.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Support/LogicalResult.h" // from @llvm-project
35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
38 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
39 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
41
42 // NOLINTNEXTLINE
43 static llvm::cl::opt<bool> enable_numeric_verify(
44 "tfl-numeric-verify", llvm::cl::value_desc("bool"),
45 llvm::cl::desc("Whether verify numericals at runtime."),
46 llvm::cl::init(false));
47
48 // NOLINTNEXTLINE
49 static llvm::cl::opt<float> error_tolerance(
50 "tfl-error-tolerance", llvm::cl::value_desc("float"),
51 llvm::cl::desc("Error tolerance for numeric verify. Valid when "
52 "`-tfl-numeric-verify` is set."),
53 llvm::cl::init(5.0));
54
55 // NOLINTNEXTLINE
56 static llvm::cl::opt<bool> enable_single_layer_verify(
57 "tfl-single-layer-verify", llvm::cl::value_desc("bool"),
58 llvm::cl::desc("Whether verify numericals layer by layer. Valid when "
59 "`-tfl-numeric-verify` is set."),
60 llvm::cl::init(true));
61
62 // NOLINTNEXTLINE
63 static llvm::cl::opt<bool> enable_log_if_failed(
64 "tfl-log-if-failed", llvm::cl::value_desc("bool"),
65 llvm::cl::desc("Whether verify numericals with thresholding "
66 "tolerance. Valid when `-tfl-numeric-verify` is set."),
67 llvm::cl::init(false));
68
69 namespace mlir {
70 namespace TFL {
71
72 //===----------------------------------------------------------------------===//
73 // The actual Quantize Pass.
74 //
75 namespace {
76
77 // Full integer quantization rewrite pattern for TFLite.
78 struct TFLFullQuantization
79 : public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
80 DequantizeOp, NumericVerifyOp> {
TFLFullQuantizationmlir::TFL::__anon24c57aa90111::TFLFullQuantization81 explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag,
82 float tolerance, bool verify_single_layer,
83 bool log_if_failed_flag = false)
84 : BaseType(ctx, verify_numeric_flag, tolerance, verify_single_layer,
85 log_if_failed_flag) {}
AllowHybridOperandmlir::TFL::__anon24c57aa90111::TFLFullQuantization86 static bool AllowHybridOperand() { return false; }
AllowHybridResultmlir::TFL::__anon24c57aa90111::TFLFullQuantization87 static bool AllowHybridResult() { return false; }
88 };
89
90 struct LegacyQuantizePass : public OpRewritePattern<QuantizeOp> {
91 // This pattern should be applied before existing quantize pattern in
92 // `quantize_patterns.td`, so the benefit is set to some value larger than 1.
LegacyQuantizePassmlir::TFL::__anon24c57aa90111::LegacyQuantizePass93 explicit LegacyQuantizePass(MLIRContext* context)
94 : OpRewritePattern<QuantizeOp>(context, /*benefit=*/10) {}
matchAndRewritemlir::TFL::__anon24c57aa90111::LegacyQuantizePass95 LogicalResult matchAndRewrite(QuantizeOp op,
96 PatternRewriter& rewriter) const override {
97 DenseFPElementsAttr attr;
98 if (matchPattern(op.input(), m_Constant(&attr))) {
99 auto qtype = op.qtypeAttr();
100 if (auto quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue())) {
101 rewriter.replaceOpWithNewOp<QConstOp>(op, qtype, quantized_attr);
102 return success();
103 }
104 }
105 return failure();
106 }
107 };
108
109 // Applies quantization on the model in TFL dialect.
110 struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
111 public:
112 // Constructor used by manually creating the pass.
QuantizePassmlir::TFL::__anon24c57aa90111::QuantizePass113 explicit QuantizePass(bool verify_numeric_flag = false,
114 bool legacy_float_scale = false)
115 : verify_numeric(verify_numeric_flag),
116 legacy_float_scale(legacy_float_scale) {}
117
118 void runOnFunction() override;
119
120 private:
121 bool verify_numeric;
122 bool legacy_float_scale;
123 };
124
125 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
126
runOnFunction()127 void QuantizePass::runOnFunction() {
128 OwningRewritePatternList patterns;
129 auto func = getFunction();
130 auto* ctx = func.getContext();
131 if (legacy_float_scale) {
132 patterns.insert<LegacyQuantizePass>(ctx);
133 }
134 TFL::populateWithGenerated(ctx, patterns);
135 patterns.insert<TFLFullQuantization>(
136 ctx, enable_numeric_verify || verify_numeric, error_tolerance,
137 enable_single_layer_verify, enable_log_if_failed);
138 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
139 }
140 } // namespace
141
142 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
CreateQuantizePass(bool verify_numeric,bool legacy_float_scale)143 std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
144 bool verify_numeric, bool legacy_float_scale) {
145 return std::make_unique<QuantizePass>(verify_numeric, legacy_float_scale);
146 }
147
148 static PassRegistration<QuantizePass> pass(
149 "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");
150
151 } // namespace TFL
152 } // namespace mlir
153