/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines node specs for quantization and the methods to
// parse command line flags to these specs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mlir {
namespace TFL {

struct QuantizationSpecs {
  // Which function these node quant specifications belong to.
  std::string target_func = "main";

  // Whether to allow weight-only quantization. This is the easiest
  // quantization mode, which requires neither QAT nor sample inputs, but it
  // can only target the DT_HALF and DT_QINT8 inference types.
  bool weight_quantization = false;

  // Whether the quantization passes are triggered for post-training
  // quantization. If true, the model input doesn't require user-specified
  // input ranges.
  // TODO(fengliuai): `weight_quantization` is just a special case of
  // post-training quantization. We need to deprecate `weight_quantization`.
  bool post_training_quantization = false;

  // Calculate scales in float to keep quantized values the same as those
  // produced by the old TOCO quantizer.
  bool legacy_float_scale = false;

  // When set to true, quantization is done per-tensor. Currently, this option
  // is only valid when the quantization parameters need to be created by
  // scanning the constant content (post-training quantization or QAT without
  // weight FakeQuant).
  bool disable_per_channel = false;

  // When set to true, the fixed output ranges of the activation ops (tanh,
  // sigmoid, etc.) and the weight constants are not inferred. Then, to
  // quantize these ops, quantization emulation ops should be placed after
  // them in the input graph. This flag should be set to false for
  // post-training quantization.
  bool disable_infer_tensor_range = false;

  // The node type when the model is exported. Currently this is limited to
  // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
  // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is
  // used, the `weight_quantization` flag needs to be set to false.
  tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

  // The input and output data type during inference. This flag is only used
  // when `inference_type` is different from DT_FLOAT. It can only be set to
  // DT_FLOAT or to the same type as `inference_type`. If it differs from
  // `inference_type`, adaptor ops are inserted as heading and tailing ops in
  // the result model.
  tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT;

  // Input node ranges. These ranges are stored in the same order as the
  // function arguments. They are only used when `weight_quantization` is set
  // to false, and the model is required to have quantization parameters,
  // either from quantization aware training or calibration, for the remaining
  // tensors.
  std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
      input_ranges;

  // The default ranges can be used when a tensor doesn't have quantization
  // parameters and couldn't be quantized. Used only for latency tests.
  std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;

  // A serialized "QuantizationInfo" object to specify value ranges for some of
  // the tensors with known names.
  std::string serialized_quant_stats = "";

  // Whether to run the passes that propagate the quantization parameters and
  // perform graph rewrites. Returns false if `inference_type` is DT_FLOAT or
  // the `weight_quantization` flag is set.
  bool RunPropagationAndRewriteQuantizationPasses() const {
    return inference_type != tensorflow::DT_FLOAT && !weight_quantization;
  }

  // Whether to run the passes that only quantize the weights.
  bool RunWeightQuantization() const { return weight_quantization; }

  // Whether this inference type represents a signed storage type.
  bool IsSignedInferenceType() const {
    switch (inference_type) {
      case tensorflow::DT_QUINT8:
      case tensorflow::DT_QUINT16:
        return false;
      default:
        return true;
    }
  }

  // Gets the width of this quantization type. Returns 0 if it isn't a
  // quantization type.
  int64_t GetQuantizationTypeWidth() const {
    switch (inference_type) {
      case tensorflow::DT_QINT8:
      case tensorflow::DT_QUINT8:
        return 8;
      case tensorflow::DT_QINT16:
      case tensorflow::DT_QUINT16:
        return 16;
      case tensorflow::DT_QINT32:
        return 32;
      default:
        return 0;
    }
  }

  // Whether to add NumericVerify ops to verify values before and after
  // quantization.
  bool verify_numeric = false;
};
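
// A minimal illustrative sketch, not part of this header's API: one plausible
// way a caller could fill in `QuantizationSpecs` for full int8 post-training
// quantization of a model whose interface stays in float. Only fields and
// methods declared above are used; the surrounding driver code is assumed.
//
//   QuantizationSpecs specs;
//   specs.inference_type = tensorflow::DT_QINT8;        // int8 internals
//   specs.inference_input_type = tensorflow::DT_FLOAT;  // float I/O, so
//                                                       // adaptor ops are
//                                                       // inserted
//   specs.post_training_quantization = true;  // ranges come from calibration
//
//   // With this configuration:
//   //   specs.RunPropagationAndRewriteQuantizationPasses() == true
//   //   specs.IsSignedInferenceType() == true
//   //   specs.GetQuantizationTypeWidth() == 8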

// Parses the command line flag strings to the quantization specification for
// input arrays of a graph. The array names are not stored in the spec, and
// will be matched by position. Returns true on failure.
bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                              absl::string_view min_values,
                              absl::string_view max_values,
                              absl::string_view inference_type,
                              QuantizationSpecs* quant_specs);

// Gets the quantization specification for input arrays. The array names are
// not stored in the spec, and will be matched by position. The min/max values
// are ignored if `inference_type` isn't a quantized type. Returns true on
// failure.
bool GetInputNodeQuantSpecs(
    const std::vector<std::string>& node_names,
    const std::vector<llvm::Optional<double>>& node_mins,
    const std::vector<llvm::Optional<double>>& node_maxs,
    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
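
// Illustrative sketch, not part of this header: invoking the flag parser
// declared above. The sample flag values are made up, and the comma-separated
// list format is an assumption based on the "matched by position" contract.
//
//   mlir::TFL::QuantizationSpecs specs;
//   bool failed = mlir::TFL::ParseInputNodeQuantSpecs(
//       /*node_names=*/"input0,input1", /*min_values=*/"-1.0,-2.0",
//       /*max_values=*/"1.0,2.0", /*inference_type=*/"DT_QINT8", &specs);
//   if (failed) {
//     // Report the invalid flag combination and abort the conversion.
//   }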