1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16
17 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
18 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
19 #include "tensorflow/lite/toco/model.h"
20 #include "tensorflow/lite/toco/tooling_util.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace toco {
24
InferQuantizedDataTypeFromFakeQuant(const FakeQuantOperator & op,ArrayDataType * out_quantized_data_type)25 bool InferQuantizedDataTypeFromFakeQuant(
26 const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
27 if (op.num_bits <= 8) {
28 *out_quantized_data_type = ArrayDataType::kUint8;
29 return true;
30 } else if (op.num_bits <= 16) {
31 *out_quantized_data_type = ArrayDataType::kInt16;
32 return true;
33 } else {
34 *out_quantized_data_type = ArrayDataType::kNone;
35 return false;
36 }
37 }
38
GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,double * out_min_value,double * out_max_value)39 bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
40 double* out_min_value,
41 double* out_max_value) {
42 switch (data_type) {
43 case ArrayDataType::kUint8:
44 *out_min_value = 0;
45 *out_max_value = 255;
46 return true;
47 case ArrayDataType::kInt16:
48 *out_min_value = -32768;
49 *out_max_value = 32767;
50 return true;
51 default:
52 return false;
53 }
54 }
55
GetQuantizedDataType(const Array & array,ArrayDataType default_type)56 ArrayDataType GetQuantizedDataType(const Array& array,
57 ArrayDataType default_type) {
58 switch (array.final_data_type) {
59 case ArrayDataType::kInt8:
60 case ArrayDataType::kUint8:
61 case ArrayDataType::kInt16:
62 case ArrayDataType::kUint16:
63 case ArrayDataType::kInt32:
64 case ArrayDataType::kUint32:
65 case ArrayDataType::kInt64:
66 case ArrayDataType::kUint64:
67 return array.final_data_type;
68 case ArrayDataType::kFloat:
69 case ArrayDataType::kNone:
70 return default_type;
71 default:
72 LOG(FATAL) << "Unhandled final quantization type "
73 << static_cast<int>(array.final_data_type);
74 }
75 }
76
77 template <ArrayDataType A>
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,QuantizationParams * quantization_params)78 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
79 const Array& array, QuantizationParams* quantization_params) {
80 *quantization_params = ::tflite::ChooseQuantizationParams<DataType<A>>(
81 array.minmax->min, array.minmax->max, array.narrow_range);
82 }
83
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,ArrayDataType quantized_data_type,QuantizationParams * quantization_params)84 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
85 const Array& array, ArrayDataType quantized_data_type,
86 QuantizationParams* quantization_params) {
87 switch (quantized_data_type) {
88 case ArrayDataType::kInt8:
89 ChooseQuantizationParamsForArrayAndQuantizedDataType<
90 ArrayDataType::kInt8>(array, quantization_params);
91 break;
92 case ArrayDataType::kUint8:
93 ChooseQuantizationParamsForArrayAndQuantizedDataType<
94 ArrayDataType::kUint8>(array, quantization_params);
95 break;
96 case ArrayDataType::kInt16:
97 ChooseQuantizationParamsForArrayAndQuantizedDataType<
98 ArrayDataType::kInt16>(array, quantization_params);
99 break;
100 case ArrayDataType::kUint16:
101 ChooseQuantizationParamsForArrayAndQuantizedDataType<
102 ArrayDataType::kUint16>(array, quantization_params);
103 break;
104 case ArrayDataType::kInt32:
105 ChooseQuantizationParamsForArrayAndQuantizedDataType<
106 ArrayDataType::kInt32>(array, quantization_params);
107 break;
108 case ArrayDataType::kUint32:
109 ChooseQuantizationParamsForArrayAndQuantizedDataType<
110 ArrayDataType::kUint32>(array, quantization_params);
111 break;
112 case ArrayDataType::kInt64:
113 ChooseQuantizationParamsForArrayAndQuantizedDataType<
114 ArrayDataType::kInt64>(array, quantization_params);
115 break;
116 case ArrayDataType::kUint64:
117 ChooseQuantizationParamsForArrayAndQuantizedDataType<
118 ArrayDataType::kUint64>(array, quantization_params);
119 break;
120 case ArrayDataType::kFloat:
121 case ArrayDataType::kComplex64:
122 case ArrayDataType::kNone:
123 default:
124 LOG(FATAL) << "Unhandled final quantization type "
125 << static_cast<int>(quantized_data_type);
126 }
127 }
128
129 namespace {
130
131 template <ArrayDataType A>
QuantizeBuffer(const Array & array,const QuantizationParams & quantization_params)132 std::unique_ptr<GenericBuffer> QuantizeBuffer(
133 const Array& array, const QuantizationParams& quantization_params) {
134 const GenericBuffer& buffer = *array.buffer;
135 const auto inverse_scale = 1. / quantization_params.scale;
136 CHECK(buffer.type == ArrayDataType::kFloat);
137 const auto& float_buffer =
138 static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
139 auto* quantized_buffer = new Buffer<A>;
140 quantized_buffer->data.resize(float_buffer.data.size());
141 for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
142 const float src_val = float_buffer.data[i];
143 double scaled_val; // Astonishingly, using 'float' degrades accuracy just
144 // enough to make a few tests fail!
145 if (quantization_params.scale == 0) {
146 CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
147 << "so all its values should be 0.";
148 scaled_val = quantization_params.zero_point;
149 } else {
150 scaled_val = quantization_params.zero_point + inverse_scale * src_val;
151 }
152 auto integer_val = tflite::SafeCast<DataType<A>>(std::round(scaled_val));
153 // In addition to its effect on the choice of quantization params upstream
154 // of here, narrow_range also means nudge the min quantized value by +1,
155 // so e.g. uint8 values get constrained to [1, 255].
156 if (integer_val == std::numeric_limits<DataType<A>>::min() &&
157 array.narrow_range) {
158 integer_val++;
159 }
160 quantized_buffer->data[i] = integer_val;
161 }
162 return std::unique_ptr<GenericBuffer>(quantized_buffer);
163 }
164
165 template <ArrayDataType A>
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,const QuantizationParams & quantization_params)166 void QuantizeArray(GraphTransformation* transformation, Model* model,
167 const std::string& name,
168 const QuantizationParams& quantization_params) {
169 auto& array = model->GetArray(name);
170 CHECK(array.data_type == ArrayDataType::kFloat);
171 CHECK(!array.quantization_params);
172 array.GetOrCreateQuantizationParams() = quantization_params;
173 if (array.buffer) {
174 array.buffer = QuantizeBuffer<A>(array, quantization_params);
175 }
176 array.data_type = A;
177 array.final_data_type = A;
178 transformation->AddMessageF(
179 "Quantized array %s to %s zero_point=%g, scale=%g", name,
180 ArrayDataTypeName(array.data_type), quantization_params.zero_point,
181 quantization_params.scale);
182 }
183
184 } // namespace
185
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,ArrayDataType quantized_data_type,const QuantizationParams & quantization_params)186 void QuantizeArray(GraphTransformation* transformation, Model* model,
187 const std::string& name, ArrayDataType quantized_data_type,
188 const QuantizationParams& quantization_params) {
189 ArrayDataType adjusted_data_type = quantized_data_type;
190 auto& array = model->GetArray(name);
191 if (array.final_data_type == ArrayDataType::kInt16) {
192 adjusted_data_type = array.final_data_type;
193 }
194
195 switch (adjusted_data_type) {
196 case ArrayDataType::kUint8:
197 return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
198 quantization_params);
199 case ArrayDataType::kInt16:
200 return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
201 quantization_params);
202 case ArrayDataType::kInt32:
203 return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
204 quantization_params);
205 default:
206 LOG(FATAL) << "Unhandled case.";
207 }
208 }
209
IsArrayQuantizedRangeSubset(GraphTransformation * transformation,const Array & array,double clamp_min,double clamp_max)210 bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
211 const Array& array, double clamp_min,
212 double clamp_max) {
213 ArrayDataType quantized_data_type =
214 GetQuantizedDataType(array, array.data_type);
215 if (quantized_data_type == ArrayDataType::kNone ||
216 quantized_data_type == ArrayDataType::kFloat) {
217 // The array is not (or never will be) quantized.
218 return false;
219 }
220
221 QuantizationParams quantization_params;
222 if (!array.quantization_params) {
223 if (!array.minmax) {
224 transformation->AddMessageF("No quantization params and no minmax");
225 return false;
226 } else {
227 // Work around cases where we are asking for this prior to the Quantize
228 // transformation having added the quantization_params.
229 ChooseQuantizationParamsForArrayAndQuantizedDataType(
230 array, quantized_data_type, &quantization_params);
231 transformation->AddMessageF(
232 "No quantization params - inferring from data type %s with minmax "
233 "%g,%g as zero_point=%g, scale=%g",
234 ArrayDataTypeName(quantized_data_type), array.minmax->min,
235 array.minmax->max, quantization_params.zero_point,
236 quantization_params.scale);
237 }
238 } else {
239 quantization_params = array.GetQuantizationParams();
240 }
241
242 double quantized_min, quantized_max;
243 CHECK(GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
244 &quantized_max))
245 << "Type is not quantized";
246
247 bool has_nontrivial_min_bound = false;
248 bool has_nontrivial_max_bound = false;
249
250 double lowest_representable_output =
251 (quantized_min - quantization_params.zero_point) *
252 quantization_params.scale;
253 if (lowest_representable_output < clamp_min) {
254 has_nontrivial_min_bound = true;
255 transformation->AddMessageF(
256 "Quantized activation function is not trivial: "
257 "the lowest representable output value %g"
258 " less than the clamp min bound %g.",
259 lowest_representable_output, clamp_min);
260 }
261
262 double highest_representable_output =
263 (quantized_max - quantization_params.zero_point) *
264 quantization_params.scale;
265 if (highest_representable_output > clamp_max) {
266 has_nontrivial_max_bound = true;
267 transformation->AddMessageF(
268 "Quantized activation function is not trivial: "
269 "the highest representable output value %g"
270 " is greater than the clamp max bound %g.",
271 highest_representable_output, clamp_max);
272 }
273
274 return !has_nontrivial_min_bound && !has_nontrivial_max_bound;
275 }
276
277 } // namespace toco
278