/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines node specs for quantization and the methods to
// parse command-line flags into these specs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"

namespace mlir {
namespace TFL {

using ::tflite::optimize::ReducedPrecisionSupport;
struct QuantizationSpecs {
  // Which function these node quantization specifications belong to.
  std::string target_func = "main";

  // Whether to allow weight-only quantization. This is the easiest
  // quantization mode; it doesn't require QAT or sample inputs, but it can
  // only target the DT_HALF and DT_QINT8 inference types.
  bool weight_quantization = false;

  // Whether the quantization passes are triggered for post-training
  // quantization. If true, the model inputs don't require user-specified
  // input ranges.
  // TODO(fengliuai): `weight_quantization` is just a special case of
  // post-training quantization, so it should be deprecated.
  bool post_training_quantization = false;

  // Calculate scales in float to keep quantized values the same as the old
  // TOCO quantizer.
  bool legacy_float_scale = false;

  // When set to true, quantization is done per-tensor. Currently, this option
  // is only valid when the quantization parameters need to be created by
  // scanning the constant content (post-training quantization or QAT without
  // weight FakeQuant).
  bool disable_per_channel = false;

  // When set to true, the fixed output ranges of the activation ops (tanh,
  // sigmoid, etc.) and the weight constants are not inferred. To quantize
  // these ops, quantization emulation ops must then be placed after them in
  // the input graph. This flag should be set to false for post-training
  // quantization.
  bool disable_infer_tensor_range = false;

  // The node type when the model is exported. Currently this is limited to
  // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
  // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is
  // used, the `weight_quantization` flag needs to be set to false.
  tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

  // The input and output data type during inference. This flag is only used
  // when `inference_type` is different from DT_FLOAT, and it can only be set
  // to DT_FLOAT or to the same value as `inference_type`. If it differs from
  // `inference_type`, adaptor ops are inserted as the leading and trailing
  // ops in the resulting model.
  tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT;

  // Input node ranges. These ranges are stored in the same order as the
  // function arguments. They are only used when `weight_quantization` is set
  // to false, and the model is required to have quantization parameters,
  // either from quantization-aware training or calibration, for the remaining
  // tensors.
  std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
      input_ranges;

  // The default ranges can be used when a tensor doesn't have quantization
  // parameters and couldn't be quantized otherwise. Used only for latency
  // tests.
  std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;

  // A serialized "QuantizationInfo" object to specify value ranges for some
  // of the tensors with known names.
  std::string serialized_quant_stats = "";

  // A bitmask to encode support for reduced precision inference in the model.
  ReducedPrecisionSupport support_mask = ReducedPrecisionSupport::None;

  // Whether to run the passes that propagate the quantization parameters and
  // rewrite the graph. Returns false if `inference_type` is DT_FLOAT or the
  // `weight_quantization` flag is set.
  bool RunPropagationAndRewriteQuantizationPasses() const {
    return inference_type != tensorflow::DT_FLOAT && !weight_quantization;
  }

  // Whether to run the passes that only quantize the weights.
  bool RunWeightQuantization() const { return weight_quantization; }

  // Whether this inference type represents a signed storage type.
  bool IsSignedInferenceType() const {
    switch (inference_type) {
      case tensorflow::DT_QUINT8:
      case tensorflow::DT_QUINT16:
        return false;
      default:
        return true;
    }
  }

  // Gets the width of this quantization type. Returns 0 if it isn't a
  // quantization type.
  int64_t GetQuantizationTypeWidth() const {
    switch (inference_type) {
      case tensorflow::DT_QINT8:
      case tensorflow::DT_QUINT8:
        return 8;
      case tensorflow::DT_QINT16:
      case tensorflow::DT_QUINT16:
        return 16;
      case tensorflow::DT_QINT32:
        return 32;
      default:
        return 0;
    }
  }

  // Whether to add NumericVerify ops to verify numbers before and after
  // quantization.
  bool verify_numeric = false;

  // Whether to verify layer by layer or on the whole model. When disabled
  // (per-layer), float and quantized ops are run from the same input (the
  // output of the previous quantized layer). When enabled, float and
  // quantized ops are run with the respective float and quantized outputs of
  // the previous ops.
  bool whole_model_verify = false;
};
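
// Example: a minimal sketch of populating a QuantizationSpecs for 8-bit
// post-training quantization, using only the members and helpers declared
// above. The exact set of fields a given conversion pipeline reads may
// differ.
//
//   QuantizationSpecs specs;
//   specs.post_training_quantization = true;
//   specs.inference_type = tensorflow::DT_QINT8;
//   specs.inference_input_type = tensorflow::DT_QINT8;
//   // A quantized inference type with `weight_quantization` left false
//   // enables the propagation and rewrite passes.
//   assert(specs.RunPropagationAndRewriteQuantizationPasses());
//   assert(specs.GetQuantizationTypeWidth() == 8);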

// Parses the command-line flag strings into the quantization specification
// for the input arrays of a graph. The array names are not stored in the
// spec and are matched by position. Returns true on failure.
bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                              absl::string_view min_values,
                              absl::string_view max_values,
                              absl::string_view inference_type,
                              QuantizationSpecs* quant_specs);
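
// Example: an illustrative call to ParseInputNodeQuantSpecs. The
// comma-separated flag values shown are an assumption for illustration only;
// consult the implementation for the exact string syntax it accepts.
//
//   QuantizationSpecs specs;
//   bool failed = ParseInputNodeQuantSpecs(
//       /*node_names=*/"input0,input1", /*min_values=*/"-1,-1",
//       /*max_values=*/"1,1", /*inference_type=*/"DT_QINT8", &specs);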

// Gets the quantization specification for the input arrays. The array names
// are not stored in the spec and are matched by position. The min/max values
// are ignored if `inference_type` isn't a quantized type. Returns true on
// failure.
bool GetInputNodeQuantSpecs(
    const std::vector<std::string>& node_names,
    const std::vector<llvm::Optional<double>>& node_mins,
    const std::vector<llvm::Optional<double>>& node_maxs,
    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
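
// Example: an illustrative sketch of building the spec directly from vectors,
// e.g. for a graph with two inputs calibrated to [-1, 1]. The node names are
// hypothetical.
//
//   QuantizationSpecs specs;
//   std::vector<std::string> names = {"input0", "input1"};
//   std::vector<llvm::Optional<double>> mins = {-1.0, -1.0};
//   std::vector<llvm::Optional<double>> maxs = {1.0, 1.0};
//   bool failed = GetInputNodeQuantSpecs(names, mins, maxs,
//                                        tensorflow::DT_QINT8, &specs);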

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_