• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation pass applies quantization to ops in the TFLite dialect.
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/Casting.h"
21 #include "llvm/Support/Debug.h"
22 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/Builders.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
29 #include "mlir/IR/Matchers.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
32 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
38 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
39 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
41 
42 // NOLINTNEXTLINE
43 static llvm::cl::opt<bool> enable_numeric_verify(
44     "tfl-numeric-verify", llvm::cl::value_desc("bool"),
45     llvm::cl::desc("Whether verify numericals at runtime."),
46     llvm::cl::init(false));
47 
48 // NOLINTNEXTLINE
49 static llvm::cl::opt<float> error_tolerance(
50     "tfl-error-tolerance", llvm::cl::value_desc("float"),
51     llvm::cl::desc("Error tolerance for numeric verify. Valid when "
52                    "`-tfl-numeric-verify` is set."),
53     llvm::cl::init(5.0));
54 
55 // NOLINTNEXTLINE
56 static llvm::cl::opt<bool> enable_single_layer_verify(
57     "tfl-single-layer-verify", llvm::cl::value_desc("bool"),
58     llvm::cl::desc("Whether verify numericals layer by layer. Valid when "
59                    "`-tfl-numeric-verify` is set."),
60     llvm::cl::init(true));
61 
62 // NOLINTNEXTLINE
63 static llvm::cl::opt<bool> enable_log_if_failed(
64     "tfl-log-if-failed", llvm::cl::value_desc("bool"),
65     llvm::cl::desc("Whether verify numericals with thresholding "
66                    "tolerance. Valid when `-tfl-numeric-verify` is set."),
67     llvm::cl::init(false));
68 
69 namespace mlir {
70 namespace TFL {
71 
72 //===----------------------------------------------------------------------===//
73 // The actual Quantize Pass.
74 //
75 namespace {
76 
77 // Full integer quantization rewrite pattern for TFLite.
78 struct TFLFullQuantization
79     : public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
80                                         DequantizeOp, NumericVerifyOp> {
TFLFullQuantizationmlir::TFL::__anon24c57aa90111::TFLFullQuantization81   explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag,
82                                float tolerance, bool verify_single_layer,
83                                bool log_if_failed_flag = false)
84       : BaseType(ctx, verify_numeric_flag, tolerance, verify_single_layer,
85                  log_if_failed_flag) {}
AllowHybridOperandmlir::TFL::__anon24c57aa90111::TFLFullQuantization86   static bool AllowHybridOperand() { return false; }
AllowHybridResultmlir::TFL::__anon24c57aa90111::TFLFullQuantization87   static bool AllowHybridResult() { return false; }
88 };
89 
90 struct LegacyQuantizePass : public OpRewritePattern<QuantizeOp> {
91   // This pattern should be applied before existing quantize pattern in
92   // `quantize_patterns.td`, so the benefit is set to some value larger than 1.
LegacyQuantizePassmlir::TFL::__anon24c57aa90111::LegacyQuantizePass93   explicit LegacyQuantizePass(MLIRContext* context)
94       : OpRewritePattern<QuantizeOp>(context, /*benefit=*/10) {}
matchAndRewritemlir::TFL::__anon24c57aa90111::LegacyQuantizePass95   LogicalResult matchAndRewrite(QuantizeOp op,
96                                 PatternRewriter& rewriter) const override {
97     DenseFPElementsAttr attr;
98     if (matchPattern(op.input(), m_Constant(&attr))) {
99       auto qtype = op.qtypeAttr();
100       if (auto quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue())) {
101         rewriter.replaceOpWithNewOp<QConstOp>(op, qtype, quantized_attr);
102         return success();
103       }
104     }
105     return failure();
106   }
107 };
108 
109 // Applies quantization on the model in TFL dialect.
110 struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
111  public:
112   // Constructor used by manually creating the pass.
QuantizePassmlir::TFL::__anon24c57aa90111::QuantizePass113   explicit QuantizePass(bool verify_numeric_flag = false,
114                         bool legacy_float_scale = false)
115       : verify_numeric(verify_numeric_flag),
116         legacy_float_scale(legacy_float_scale) {}
117 
118   void runOnFunction() override;
119 
120  private:
121   bool verify_numeric;
122   bool legacy_float_scale;
123 };
124 
125 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
126 
runOnFunction()127 void QuantizePass::runOnFunction() {
128   OwningRewritePatternList patterns;
129   auto func = getFunction();
130   auto* ctx = func.getContext();
131   if (legacy_float_scale) {
132     patterns.insert<LegacyQuantizePass>(ctx);
133   }
134   TFL::populateWithGenerated(ctx, patterns);
135   patterns.insert<TFLFullQuantization>(
136       ctx, enable_numeric_verify || verify_numeric, error_tolerance,
137       enable_single_layer_verify, enable_log_if_failed);
138   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
139 }
140 }  // namespace
141 
142 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
CreateQuantizePass(bool verify_numeric,bool legacy_float_scale)143 std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
144     bool verify_numeric, bool legacy_float_scale) {
145   return std::make_unique<QuantizePass>(verify_numeric, legacy_float_scale);
146 }
147 
148 static PassRegistration<QuantizePass> pass(
149     "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");
150 
151 }  // namespace TFL
152 }  // namespace mlir
153