1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass applies quantization on TFLite dialect.
17
18 #include <cstddef>
19 #include <string>
20
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/StringSwitch.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/Debug.h"
25 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
26 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/MLIRContext.h" // from @llvm-project
32 #include "mlir/IR/Matchers.h" // from @llvm-project
33 #include "mlir/IR/Operation.h" // from @llvm-project
34 #include "mlir/IR/OperationSupport.h" // from @llvm-project
35 #include "mlir/IR/PatternMatch.h" // from @llvm-project
36 #include "mlir/Pass/Pass.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
41 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
44
45 // NOLINTNEXTLINE
46 static llvm::cl::opt<bool> enable_numeric_verify(
47 "tfl-numeric-verify", llvm::cl::value_desc("bool"),
48 llvm::cl::desc("Whether verify numericals at runtime."),
49 llvm::cl::init(false));
50
51 // NOLINTNEXTLINE
52 static llvm::cl::opt<float> error_tolerance(
53 "tfl-error-tolerance", llvm::cl::value_desc("float"),
54 llvm::cl::desc("Error tolerance for numeric verify. Valid when "
55 "`-tfl-numeric-verify` is set."),
56 llvm::cl::init(5.0));
57
58 // NOLINTNEXTLINE
59 static llvm::cl::opt<bool> enable_whole_model_verify(
60 "tfl-whole-model-verify", llvm::cl::value_desc("bool"),
61 llvm::cl::desc("Whether verify numericals layer by layer or whole model. "
62 "Valid when `-tfl-numeric-verify` is set."),
63 llvm::cl::init(false));
64
65 // NOLINTNEXTLINE
66 static llvm::cl::opt<bool> enable_log_if_failed(
67 "tfl-log-if-failed", llvm::cl::value_desc("bool"),
68 llvm::cl::desc("Whether verify numericals with thresholding "
69 "tolerance. Valid when `-tfl-numeric-verify` is set."),
70 llvm::cl::init(false));
71
72 // NOLINTNEXTLINE
73 static llvm::cl::opt<bool> enable_legacy_quantize(
74 "tfl-legacy-quantize", llvm::cl::value_desc("bool"),
75 llvm::cl::desc("Use legacy quantize mode in test. Valid when"
76 "`-tfl-legacy-quantize` is set."),
77 llvm::cl::init(false));
78
79 // NOLINTNEXTLINE
80 static llvm::cl::list<std::string> ops_blocklist_flag(
81 "tfl-ops-blocklist",
82 llvm::cl::desc("Names of ops to blocklist from quantization"),
83 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
84
85 // NOLINTNEXTLINE
86 static llvm::cl::list<std::string> nodes_blocklist_flag(
87 "tfl-locs-blocklist",
88 llvm::cl::desc("Names of location to blocklist from quantization"),
89 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
90
91 namespace mlir {
92 namespace TFL {
93
94 //===----------------------------------------------------------------------===//
95 // The actual Quantize Pass.
96 //
97 namespace {
98
99 // Full integer quantization rewrite pattern using DQ as the root op.
100 struct TFLFullQuantization
101 : public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
102 DequantizeOp, NumericVerifyOp> {
TFLFullQuantizationmlir::TFL::__anon47583dee0111::TFLFullQuantization103 explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag,
104 float tolerance, bool verify_whole_model,
105 bool log_if_failed_flag = false,
106 const StringSet& ops_blocklist_flag = {},
107 const StringSet& nodes_blocklist_flag = {})
108 : BaseType(ctx, verify_numeric_flag, tolerance, verify_whole_model,
109 log_if_failed_flag, ops_blocklist_flag, nodes_blocklist_flag) {
110 }
AllowHybridOperandmlir::TFL::__anon47583dee0111::TFLFullQuantization111 static bool AllowHybridOperand() { return false; }
AllowHybridResultmlir::TFL::__anon47583dee0111::TFLFullQuantization112 static bool AllowHybridResult() { return false; }
113 };
114
115 // Full integer quantization rewrite pattern using Q as the root op. This is for
116 // the quantizable ops without floating-point operands.
117 struct TFLFullQuantizationReverse
118 : public quant::QuantizationPattern<TFLFullQuantizationReverse, QuantizeOp,
119 DequantizeOp, NumericVerifyOp,
120 QuantizeOp> {
TFLFullQuantizationReversemlir::TFL::__anon47583dee0111::TFLFullQuantizationReverse121 explicit TFLFullQuantizationReverse(
122 MLIRContext* ctx, bool verify_numeric_flag, float tolerance,
123 bool verify_whole_model, bool log_if_failed_flag = false,
124 const StringSet& ops_blocklist_flag = {},
125 const StringSet& nodes_blocklist_flag = {})
126 : BaseType(ctx, verify_numeric_flag, tolerance, verify_whole_model,
127 log_if_failed_flag, ops_blocklist_flag, nodes_blocklist_flag) {
128 }
AllowHybridOperandmlir::TFL::__anon47583dee0111::TFLFullQuantizationReverse129 static bool AllowHybridOperand() { return false; }
AllowHybridResultmlir::TFL::__anon47583dee0111::TFLFullQuantizationReverse130 static bool AllowHybridResult() { return false; }
131 };
132
133 struct QuantizeConstPattern : public OpRewritePattern<QuantizeOp> {
QuantizeConstPatternmlir::TFL::__anon47583dee0111::QuantizeConstPattern134 explicit QuantizeConstPattern(MLIRContext* context, bool legacy_float_scale)
135 : OpRewritePattern<QuantizeOp>(context),
136 legacy_float_scale(legacy_float_scale) {}
matchAndRewritemlir::TFL::__anon47583dee0111::QuantizeConstPattern137 LogicalResult matchAndRewrite(QuantizeOp op,
138 PatternRewriter& rewriter) const override {
139 DenseFPElementsAttr attr;
140 if (matchPattern(op.input(), m_Constant(&attr))) {
141 auto qtype = op.qtypeAttr();
142 Attribute quantized_attr;
143 if (legacy_float_scale) {
144 quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue());
145 } else {
146 quantized_attr = quant::Quantize(attr, qtype.getValue());
147 }
148 if (quantized_attr) {
149 rewriter.replaceOpWithNewOp<QConstOp>(op, qtype, quantized_attr);
150 return success();
151 }
152 }
153 return failure();
154 }
155
156 private:
157 bool legacy_float_scale;
158 };
159
160 #define LIST_FLAG_OR_STRING_SET(list, set) \
161 (!list.empty() ? StringSet(list.begin(), list.end()) : set)
162
163 // Applies quantization on the model in TFL dialect.
164 struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
165 public:
166 // Constructor used by manually creating the pass.
QuantizePassmlir::TFL::__anon47583dee0111::QuantizePass167 explicit QuantizePass(bool verify_numeric_flag = false,
168 bool verify_whole_model = true,
169 bool legacy_float_scale = false,
170 const StringSet& ops_blocklist_set = {},
171 const StringSet& nodes_blocklist_set = {})
172 : verify_numeric(verify_numeric_flag),
173 verify_whole_model(verify_whole_model),
174 legacy_float_scale(legacy_float_scale),
175 ops_blocklist(
176 LIST_FLAG_OR_STRING_SET(ops_blocklist_flag, ops_blocklist_set)),
177 nodes_blocklist(LIST_FLAG_OR_STRING_SET(nodes_blocklist_flag,
178 nodes_blocklist_set)) {}
179
getArgumentmlir::TFL::__anon47583dee0111::QuantizePass180 StringRef getArgument() const final {
181 // This is the argument used to refer to the pass in
182 // the textual format (on the commandline for example).
183 return "tfl-quantize";
184 }
getDescriptionmlir::TFL::__anon47583dee0111::QuantizePass185 StringRef getDescription() const final {
186 // This is a brief description of the pass.
187 return "Apply quantization on models in TensorFlow Lite dialect";
188 }
189
190 void runOnFunction() override;
191
192 private:
193 bool verify_numeric;
194 bool verify_whole_model;
195 bool legacy_float_scale;
196 const StringSet ops_blocklist;
197 const StringSet nodes_blocklist;
198 };
199
200 #undef LIST_FLAG_OR_STRING_SET
201
202 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
203
runOnFunction()204 void QuantizePass::runOnFunction() {
205 OwningRewritePatternList patterns(&getContext());
206 auto func = getFunction();
207 auto* ctx = func.getContext();
208
209 TFL::populateWithGenerated(patterns);
210 patterns.insert<TFLFullQuantization, TFLFullQuantizationReverse>(
211 ctx, enable_numeric_verify || verify_numeric, error_tolerance,
212 enable_whole_model_verify || verify_whole_model, enable_log_if_failed,
213 ops_blocklist, nodes_blocklist);
214 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
215
216 // Constant quantization is a lossy transformation, so they are applied only
217 // after all the other patterns have been aplied.
218 OwningRewritePatternList patterns_2(&getContext());
219 patterns_2.insert<QuantizeConstPattern>(
220 ctx, legacy_float_scale || enable_legacy_quantize);
221 (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));
222 }
223 } // namespace
224
225 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
CreateQuantizePass(bool verify_numeric,bool whole_model_verify,bool legacy_float_scale,const StringSet & ops_blocklist,const StringSet & nodes_blocklist)226 std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
227 bool verify_numeric, bool whole_model_verify, bool legacy_float_scale,
228 const StringSet& ops_blocklist, const StringSet& nodes_blocklist) {
229 return std::make_unique<QuantizePass>(verify_numeric, whole_model_verify,
230 legacy_float_scale, ops_blocklist,
231 nodes_blocklist);
232 }
233
234 static PassRegistration<QuantizePass> pass;
235
236 } // namespace TFL
237 } // namespace mlir
238