/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines node specs for quantization and the methods to parse
// command line flags to these specs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_

#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mlir {
namespace TFL {

struct QuantizationSpecs {
  // Which function these node quantization specifications belong to.
  std::string target_func = "main";

  // Whether to allow weight-only quantization. This is the easiest
  // quantization mode: it doesn't require QAT or sample inputs, but it can
  // only target the DT_HALF and DT_QINT8 inference types.
  bool weight_quantization = false;

  // Whether the quantization passes are triggered for post-training
  // quantization. If true, the model inputs don't require user-specified
  // input ranges.
  // TODO(fengliuai): `weight_quantization` is just a special case of
  // post-training quantization. We need to deprecate `weight_quantization`.
  bool post_training_quantization = false;

  // The node type when the model is exported. Currently this is limited to
  // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
  // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is
  // used, the `weight_quantization` flag needs to be set to false.
  tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

  // The input and output data type during inference. This flag is only used
  // when `inference_type` is different from DT_FLOAT. It can only be set to
  // DT_FLOAT or to the same value as `inference_type`. If this flag differs
  // from `inference_type`, adaptor ops are inserted as the leading and
  // trailing ops in the resulting model.
  tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT;

  // Input node ranges. These ranges are stored in the same order as the
  // function arguments. They are only used when `weight_quantization` is set
  // to false, and the model is then required to have quantization parameters,
  // either from quantization-aware training or calibration, for the remaining
  // tensors.
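  //
  // For example (hypothetical values, not taken from TensorFlow
  // documentation), a function with two arguments could use
  // {{-1.0, 1.0}, {0.0, 6.0}}: the i-th {min, max} pair is matched with the
  // i-th function argument.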
  std::vector<std::pair<double, double>> input_ranges;

  // The default ranges can be used when a tensor doesn't have quantization
  // parameters and couldn't be quantized. Used only for latency tests.
  std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;

  // A serialized "QuantizationInfo" object to specify value ranges for some of
  // the tensors with known names.
  std::string serialized_quant_stats = "";

  // Whether to run the passes that propagate the quantization parameters and
  // perform graph rewrites. Returns false if `inference_type` is DT_FLOAT or
  // the `weight_quantization` flag is set.
  bool RunPropagationAndRewriteQuantizationPasses() const {
    return inference_type != tensorflow::DT_FLOAT && !weight_quantization;
  }

  // Whether to run the passes that only quantize the weights.
  bool RunWeightQuantization() const { return weight_quantization; }

  // Whether this inference type represents a signed storage type.
  bool IsSignedInferenceType() {
    switch (inference_type) {
      case tensorflow::DT_QUINT8:
      case tensorflow::DT_QUINT16:
        return false;
      default:
        return true;
    }
  }

  // Gets the width of this quantization type. Returns 0 if it isn't a
  // quantization type.
  int64_t GetQuantizationTypeWidth() {
    switch (inference_type) {
      case tensorflow::DT_QINT8:
      case tensorflow::DT_QUINT8:
        return 8;
      case tensorflow::DT_QINT16:
      case tensorflow::DT_QUINT16:
        return 16;
      case tensorflow::DT_QINT32:
        return 32;
      default:
        return 0;
    }
  }
};

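// Example (a hypothetical usage sketch, not part of this header): configuring
// the specs for post-training quantization to 8-bit signed integer execution
// while keeping float inputs and outputs:
//
//   QuantizationSpecs specs;
//   specs.inference_type = tensorflow::DT_QINT8;
//   specs.inference_input_type = tensorflow::DT_FLOAT;
//   specs.post_training_quantization = true;
//   // `inference_type` isn't DT_FLOAT and `weight_quantization` is false, so
//   // the propagation and rewrite passes will run.
//   assert(specs.RunPropagationAndRewriteQuantizationPasses());
//   assert(specs.IsSignedInferenceType());          // DT_QINT8 is signed.
//   assert(specs.GetQuantizationTypeWidth() == 8);  // 8-bit storage type.
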
// Parses the command line flag strings to the quantization specification for
// input arrays of a graph. The array names are not stored in the spec, and
// will be matched by position. Returns true on failure.
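//
// Sketch of a call (hypothetical flag values; the exact string formats the
// parser accepts are defined by the implementation, not by this header):
//
//   QuantizationSpecs specs;
//   if (ParseInputNodeQuantSpecs(/*node_names=*/"input0,input1",
//                                /*min_values=*/"-1.0,0.0",
//                                /*max_values=*/"1.0,6.0",
//                                /*inference_type=*/"DT_QUINT8", &specs)) {
//     // A return value of true indicates failure.
//   }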
bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                              absl::string_view min_values,
                              absl::string_view max_values,
                              absl::string_view inference_type,
                              QuantizationSpecs* quant_specs);

// Gets the quantization specification for input arrays. The array names are
// not stored in the spec, and will be matched by position. The min/max values
// will be ignored if `inference_type` isn't a quantized type. Returns true on
// failure.
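//
// Sketch of a call (hypothetical values): the i-th min/max pair is applied to
// the node name at the same index:
//
//   QuantizationSpecs specs;
//   bool failed = GetInputNodeQuantSpecs({"input0", "input1"},
//                                        {-1.0, 0.0}, {1.0, 6.0},
//                                        tensorflow::DT_QUINT8, &specs);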
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
                            const std::vector<double>& node_mins,
                            const std::vector<double>& node_maxs,
                            tensorflow::DataType inference_type,
                            QuantizationSpecs* quant_specs);

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_