/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines node specs for quantization and the methods to
// parse command line flags into these specs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mlir {
namespace TFL {

struct QuantizationSpecs {
  // The function these node quantization specs belong to.
  std::string target_func = "main";

  // Whether to allow weight-only quantization. This is the easiest
  // quantization mode, which requires neither QAT nor sample inputs, but it
  // can only target the DT_HALF and DT_QINT8 inference types.
  bool weight_quantization = false;

  // Whether the quantization passes are triggered for post-training
  // quantization. If true, the model input doesn't require user-specified
  // input ranges.
  // TODO(fengliuai): `weight_quantization` is just a special case of
  // post-training quantization. We need to deprecate it.
  bool post_training_quantization = false;

  // Calculate scales in float to keep quantized values the same as the old
  // TOCO quantizer.
  bool legacy_float_scale = false;

  // When set to true, quantization will be done per-tensor. Currently, this
  // option is only valid when the quantization parameters need to be created
  // by scanning the constant content (post-training quantization or QAT
  // without weight FakeQuant).
  bool disable_per_channel = false;

  // When set to true, the fixed output ranges of the activation ops (tanh,
  // sigmoid, etc.) and the weight constants are not inferred. Then, to
  // quantize these ops, quantization emulation ops should be placed after
  // them in the input graph. This flag should be set to false for
  // post-training quantization.
  bool disable_infer_tensor_range = false;

  // The node type when the model is exported. Currently this is limited to
  // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
  // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is
  // used, the `weight_quantization` flag needs to be set to false.
  tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

  // The input and output data type during inference. This flag is only used
  // when `inference_type` is different from DT_FLOAT. It can only be set to
  // DT_FLOAT or to the same type as `inference_type`. If it differs from
  // `inference_type`, adaptor ops are inserted as leading and trailing ops in
  // the resulting model.
  tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT;

  // Input node ranges, stored in the same order as the function arguments.
  // They are only used when `weight_quantization` is set to false, and the
  // model is required to have quantization parameters, either from
  // quantization-aware training or calibration, for the remaining tensors.
  std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
      input_ranges;

  // The default ranges can be used when a tensor doesn't have quantization
  // parameters and cannot otherwise be quantized. Used only for latency
  // tests.
  std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;

  // A serialized "QuantizationInfo" object to specify value ranges for some
  // of the tensors with known names.
  std::string serialized_quant_stats = "";

  // Whether to run the passes that propagate the quantization parameters and
  // rewrite the graph. Returns false if `inference_type` is DT_FLOAT or the
  // `weight_quantization` flag is set.
  bool RunPropagationAndRewriteQuantizationPasses() const {
    return inference_type != tensorflow::DT_FLOAT && !weight_quantization;
  }

  // Whether to run the passes that only quantize the weights.
  bool RunWeightQuantization() const { return weight_quantization; }

  // Whether this inference type represents a signed storage type.
  bool IsSignedInferenceType() const {
    switch (inference_type) {
      case tensorflow::DT_QUINT8:
      case tensorflow::DT_QUINT16:
        return false;
      default:
        return true;
    }
  }

  // Gets the width of this quantization type. Returns 0 if it isn't a
  // quantization type.
  int64_t GetQuantizationTypeWidth() const {
    switch (inference_type) {
      case tensorflow::DT_QINT8:
      case tensorflow::DT_QUINT8:
        return 8;
      case tensorflow::DT_QINT16:
      case tensorflow::DT_QUINT16:
        return 16;
      case tensorflow::DT_QINT32:
        return 32;
      default:
        return 0;
    }
  }

  // Whether to add NumericVerify ops to verify numeric values before and
  // after quantization.
  bool verify_numeric = false;
};
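
// A minimal usage sketch (illustrative; the pass pipeline that consumes these
// specs is assembled elsewhere in the converter):
//
//   QuantizationSpecs specs;
//   specs.inference_type = tensorflow::DT_QINT8;
//   specs.post_training_quantization = true;
//   // With a quantized inference type and weight-only quantization off,
//   // RunPropagationAndRewriteQuantizationPasses() returns true and
//   // GetQuantizationTypeWidth() returns 8.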

// Parses the command line flag strings to the quantization specification for
// input arrays of a graph. The array names are not stored in the spec, and
// will be matched by position. Returns true on failure.
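// For example, a call shaped like the following is expected to populate the
// specs positionally from the flag strings (the comma-separated syntax shown
// here is illustrative; see the implementation for the authoritative format):
//
//   QuantizationSpecs specs;
//   bool failed = ParseInputNodeQuantSpecs(
//       /*node_names=*/"input0,input1", /*min_values=*/"-1.0,-1.0",
//       /*max_values=*/"1.0,1.0", /*inference_type=*/"DT_QUINT8", &specs);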
bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                              absl::string_view min_values,
                              absl::string_view max_values,
                              absl::string_view inference_type,
                              QuantizationSpecs* quant_specs);

// Gets the quantization specification for input arrays. The array names are
// not stored in the spec, and will be matched by position. The min/max will
// be ignored if the inference_type isn't a quantized type. Returns true on
// failure.
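// For example (the node name and range values below are placeholders):
//
//   QuantizationSpecs specs;
//   std::vector<std::string> names = {"input"};
//   std::vector<llvm::Optional<double>> mins = {-1.0};
//   std::vector<llvm::Optional<double>> maxs = {1.0};
//   bool failed = GetInputNodeQuantSpecs(names, mins, maxs,
//                                        tensorflow::DT_QUINT8, &specs);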
bool GetInputNodeQuantSpecs(
    const std::vector<std::string>& node_names,
    const std::vector<llvm::Optional<double>>& node_mins,
    const std::vector<llvm::Optional<double>>& node_maxs,
    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_