/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines node specs for quantization and the methods to
// parse command line flags into these specs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mlir {
namespace TFL {

struct QuantizationSpecs {
  // Which function these node quantization specifications belong to.
  std::string target_func = "main";

  // Whether to allow weight-only quantization. This is the simplest
  // quantization mode: it requires neither QAT nor sample inputs, but it can
  // only target the DT_HALF and DT_QINT8 inference types.
  bool weight_quantization = false;

  // Whether the quantization passes are triggered for post-training
  // quantization. If true, the model inputs don't require user-specified
  // input ranges.
  // TODO(fengliuai): `weight_quantization` is just a special case of
  // post-training quantization. We need to deprecate `weight_quantization`.
  bool post_training_quantization = false;

  // The node type when the model is exported. Currently this is limited to
  // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
  // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is
  // used, the `weight_quantization` flag needs to be set to false.
  tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

  // The input and output data type during inference. This flag is only used
  // when `inference_type` is different from DT_FLOAT. It can only be set to
  // DT_FLOAT or to the same value as `inference_type`. If it differs from
  // `inference_type`, adaptor ops are inserted as leading and trailing ops in
  // the resulting model.
  tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT;

  // Input node ranges. These ranges are stored in the same order as the
  // function arguments. They are only used when `weight_quantization` is set
  // to false, and the model is required to have quantization parameters,
  // either from quantization aware training or calibration, for the remaining
  // tensors.
  std::vector<std::pair<double, double>> input_ranges;

  // The default ranges can be used when a tensor doesn't have quantization
  // parameters and couldn't be quantized. Used only for latency tests.
  std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;

  // A serialized "QuantizationInfo" object to specify value ranges for some of
  // the tensors with known names.
  std::string serialized_quant_stats = "";

  // Whether to run the passes that propagate the quantization parameters and
  // perform graph rewrites. Returns false if `inference_type` is DT_FLOAT or
  // the `weight_quantization` flag is set.
  bool RunPropagationAndRewriteQuantizationPasses() const {
    return inference_type != tensorflow::DT_FLOAT && !weight_quantization;
  }

  // Whether to run the passes that only quantize the weights.
  bool RunWeightQuantization() const { return weight_quantization; }

  // Whether this inference type represents a signed storage type.
  bool IsSignedInferenceType() const {
    switch (inference_type) {
      case tensorflow::DT_QUINT8:
      case tensorflow::DT_QUINT16:
        return false;
      default:
        return true;
    }
  }

  // Gets the width of this quantization type. Returns 0 if it isn't a
  // quantization type.
  int64_t GetQuantizationTypeWidth() const {
    switch (inference_type) {
      case tensorflow::DT_QINT8:
      case tensorflow::DT_QUINT8:
        return 8;
      case tensorflow::DT_QINT16:
      case tensorflow::DT_QUINT16:
        return 16;
      case tensorflow::DT_QINT32:
        return 32;
      default:
        return 0;
    }
  }
};

// Parses the command line flag strings into the quantization specification for
// the input arrays of a graph. The array names are not stored in the spec, and
// will be matched by position. Returns true on failure.
bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                              absl::string_view min_values,
                              absl::string_view max_values,
                              absl::string_view inference_type,
                              QuantizationSpecs* quant_specs);

// Gets the quantization specification for the input arrays. The array names
// are not stored in the spec, and will be matched by position. The min/max
// values will be ignored if `inference_type` isn't a quantized type. Returns
// true on failure.
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
                            const std::vector<double>& node_mins,
                            const std::vector<double>& node_maxs,
                            tensorflow::DataType inference_type,
                            QuantizationSpecs* quant_specs);

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
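
// Example usage (a minimal sketch, not part of the original header): building
// a QuantizationSpecs from command-line style strings and checking the helper
// predicates. The comma-separated flag format and the "DT_QINT8" spelling of
// the inference type are assumptions made for illustration.
//
//   mlir::TFL::QuantizationSpecs quant_specs;
//   // Per the declaration above, ParseInputNodeQuantSpecs returns true on
//   // failure, e.g. when the flag strings cannot be parsed.
//   if (mlir::TFL::ParseInputNodeQuantSpecs(
//           /*node_names=*/"input0,input1", /*min_values=*/"-1.0,-2.0",
//           /*max_values=*/"1.0,2.0", /*inference_type=*/"DT_QINT8",
//           &quant_specs)) {
//     // Handle malformed flags here.
//   }
//   // With a quantized inference type and `weight_quantization` left false,
//   // the propagation and rewrite passes are expected to run.
//   bool run_passes =
//       quant_specs.RunPropagationAndRewriteQuantizationPasses();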