• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass applies quantization on TFLite dialect.
17 
18 #include <cstddef>
19 #include <string>
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/StringSwitch.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/Debug.h"
25 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
32 #include "mlir/IR/Matchers.h"  // from @llvm-project
33 #include "mlir/IR/Operation.h"  // from @llvm-project
34 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
35 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
36 #include "mlir/Pass/Pass.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
41 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
44 
45 // NOLINTNEXTLINE
46 static llvm::cl::opt<bool> enable_numeric_verify(
47     "tfl-numeric-verify", llvm::cl::value_desc("bool"),
48     llvm::cl::desc("Whether verify numericals at runtime."),
49     llvm::cl::init(false));
50 
51 // NOLINTNEXTLINE
52 static llvm::cl::opt<float> error_tolerance(
53     "tfl-error-tolerance", llvm::cl::value_desc("float"),
54     llvm::cl::desc("Error tolerance for numeric verify. Valid when "
55                    "`-tfl-numeric-verify` is set."),
56     llvm::cl::init(5.0));
57 
58 // NOLINTNEXTLINE
59 static llvm::cl::opt<bool> enable_whole_model_verify(
60     "tfl-whole-model-verify", llvm::cl::value_desc("bool"),
61     llvm::cl::desc("Whether verify numericals layer by layer or whole model. "
62                    "Valid when `-tfl-numeric-verify` is set."),
63     llvm::cl::init(false));
64 
65 // NOLINTNEXTLINE
66 static llvm::cl::opt<bool> enable_log_if_failed(
67     "tfl-log-if-failed", llvm::cl::value_desc("bool"),
68     llvm::cl::desc("Whether verify numericals with thresholding "
69                    "tolerance. Valid when `-tfl-numeric-verify` is set."),
70     llvm::cl::init(false));
71 
72 // NOLINTNEXTLINE
73 static llvm::cl::opt<bool> enable_legacy_quantize(
74     "tfl-legacy-quantize", llvm::cl::value_desc("bool"),
75     llvm::cl::desc("Use legacy quantize mode in test. Valid when"
76                    "`-tfl-legacy-quantize` is set."),
77     llvm::cl::init(false));
78 
79 // NOLINTNEXTLINE
80 static llvm::cl::list<std::string> ops_blocklist_flag(
81     "tfl-ops-blocklist",
82     llvm::cl::desc("Names of ops to blocklist from quantization"),
83     llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
84 
85 // NOLINTNEXTLINE
86 static llvm::cl::list<std::string> nodes_blocklist_flag(
87     "tfl-locs-blocklist",
88     llvm::cl::desc("Names of location to blocklist from quantization"),
89     llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
90 
91 namespace mlir {
92 namespace TFL {
93 
94 //===----------------------------------------------------------------------===//
95 // The actual Quantize Pass.
96 //
97 namespace {
98 
99 // Full integer quantization rewrite pattern using DQ as the root op.
100 struct TFLFullQuantization
101     : public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
102                                         DequantizeOp, NumericVerifyOp> {
TFLFullQuantizationmlir::TFL::__anon47583dee0111::TFLFullQuantization103   explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag,
104                                float tolerance, bool verify_whole_model,
105                                bool log_if_failed_flag = false,
106                                const StringSet& ops_blocklist_flag = {},
107                                const StringSet& nodes_blocklist_flag = {})
108       : BaseType(ctx, verify_numeric_flag, tolerance, verify_whole_model,
109                  log_if_failed_flag, ops_blocklist_flag, nodes_blocklist_flag) {
110   }
AllowHybridOperandmlir::TFL::__anon47583dee0111::TFLFullQuantization111   static bool AllowHybridOperand() { return false; }
AllowHybridResultmlir::TFL::__anon47583dee0111::TFLFullQuantization112   static bool AllowHybridResult() { return false; }
113 };
114 
115 // Full integer quantization rewrite pattern using Q as the root op. This is for
116 // the quantizable ops without floating-point operands.
117 struct TFLFullQuantizationReverse
118     : public quant::QuantizationPattern<TFLFullQuantizationReverse, QuantizeOp,
119                                         DequantizeOp, NumericVerifyOp,
120                                         QuantizeOp> {
TFLFullQuantizationReversemlir::TFL::__anon47583dee0111::TFLFullQuantizationReverse121   explicit TFLFullQuantizationReverse(
122       MLIRContext* ctx, bool verify_numeric_flag, float tolerance,
123       bool verify_whole_model, bool log_if_failed_flag = false,
124       const StringSet& ops_blocklist_flag = {},
125       const StringSet& nodes_blocklist_flag = {})
126       : BaseType(ctx, verify_numeric_flag, tolerance, verify_whole_model,
127                  log_if_failed_flag, ops_blocklist_flag, nodes_blocklist_flag) {
128   }
AllowHybridOperandmlir::TFL::__anon47583dee0111::TFLFullQuantizationReverse129   static bool AllowHybridOperand() { return false; }
AllowHybridResultmlir::TFL::__anon47583dee0111::TFLFullQuantizationReverse130   static bool AllowHybridResult() { return false; }
131 };
132 
133 struct QuantizeConstPattern : public OpRewritePattern<QuantizeOp> {
QuantizeConstPatternmlir::TFL::__anon47583dee0111::QuantizeConstPattern134   explicit QuantizeConstPattern(MLIRContext* context, bool legacy_float_scale)
135       : OpRewritePattern<QuantizeOp>(context),
136         legacy_float_scale(legacy_float_scale) {}
matchAndRewritemlir::TFL::__anon47583dee0111::QuantizeConstPattern137   LogicalResult matchAndRewrite(QuantizeOp op,
138                                 PatternRewriter& rewriter) const override {
139     DenseFPElementsAttr attr;
140     if (matchPattern(op.input(), m_Constant(&attr))) {
141       auto qtype = op.qtypeAttr();
142       Attribute quantized_attr;
143       if (legacy_float_scale) {
144         quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue());
145       } else {
146         quantized_attr = quant::Quantize(attr, qtype.getValue());
147       }
148       if (quantized_attr) {
149         rewriter.replaceOpWithNewOp<QConstOp>(op, qtype, quantized_attr);
150         return success();
151       }
152     }
153     return failure();
154   }
155 
156  private:
157   bool legacy_float_scale;
158 };
159 
160 #define LIST_FLAG_OR_STRING_SET(list, set) \
161   (!list.empty() ? StringSet(list.begin(), list.end()) : set)
162 
163 // Applies quantization on the model in TFL dialect.
164 struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
165  public:
166   // Constructor used by manually creating the pass.
QuantizePassmlir::TFL::__anon47583dee0111::QuantizePass167   explicit QuantizePass(bool verify_numeric_flag = false,
168                         bool verify_whole_model = true,
169                         bool legacy_float_scale = false,
170                         const StringSet& ops_blocklist_set = {},
171                         const StringSet& nodes_blocklist_set = {})
172       : verify_numeric(verify_numeric_flag),
173         verify_whole_model(verify_whole_model),
174         legacy_float_scale(legacy_float_scale),
175         ops_blocklist(
176             LIST_FLAG_OR_STRING_SET(ops_blocklist_flag, ops_blocklist_set)),
177         nodes_blocklist(LIST_FLAG_OR_STRING_SET(nodes_blocklist_flag,
178                                                 nodes_blocklist_set)) {}
179 
getArgumentmlir::TFL::__anon47583dee0111::QuantizePass180   StringRef getArgument() const final {
181     // This is the argument used to refer to the pass in
182     // the textual format (on the commandline for example).
183     return "tfl-quantize";
184   }
getDescriptionmlir::TFL::__anon47583dee0111::QuantizePass185   StringRef getDescription() const final {
186     // This is a brief description of the pass.
187     return "Apply quantization on models in TensorFlow Lite dialect";
188   }
189 
190   void runOnFunction() override;
191 
192  private:
193   bool verify_numeric;
194   bool verify_whole_model;
195   bool legacy_float_scale;
196   const StringSet ops_blocklist;
197   const StringSet nodes_blocklist;
198 };
199 
200 #undef LIST_FLAG_OR_STRING_SET
201 
202 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
203 
runOnFunction()204 void QuantizePass::runOnFunction() {
205   OwningRewritePatternList patterns(&getContext());
206   auto func = getFunction();
207   auto* ctx = func.getContext();
208 
209   TFL::populateWithGenerated(patterns);
210   patterns.insert<TFLFullQuantization, TFLFullQuantizationReverse>(
211       ctx, enable_numeric_verify || verify_numeric, error_tolerance,
212       enable_whole_model_verify || verify_whole_model, enable_log_if_failed,
213       ops_blocklist, nodes_blocklist);
214   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
215 
216   // Constant quantization is a lossy transformation, so they are applied only
217   // after all the other patterns have been aplied.
218   OwningRewritePatternList patterns_2(&getContext());
219   patterns_2.insert<QuantizeConstPattern>(
220       ctx, legacy_float_scale || enable_legacy_quantize);
221   (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));
222 }
223 }  // namespace
224 
225 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
CreateQuantizePass(bool verify_numeric,bool whole_model_verify,bool legacy_float_scale,const StringSet & ops_blocklist,const StringSet & nodes_blocklist)226 std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
227     bool verify_numeric, bool whole_model_verify, bool legacy_float_scale,
228     const StringSet& ops_blocklist, const StringSet& nodes_blocklist) {
229   return std::make_unique<QuantizePass>(verify_numeric, whole_model_verify,
230                                         legacy_float_scale, ops_blocklist,
231                                         nodes_blocklist);
232 }
233 
234 static PassRegistration<QuantizePass> pass;
235 
236 }  // namespace TFL
237 }  // namespace mlir
238