• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <memory>
16 
17 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
18 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
19 #include "tensorflow/lite/toco/model.h"
20 #include "tensorflow/lite/toco/tooling_util.h"
21 #include "tensorflow/core/platform/logging.h"
22 
23 namespace toco {
24 
InferQuantizedDataTypeFromFakeQuant(const FakeQuantOperator & op,ArrayDataType * out_quantized_data_type)25 bool InferQuantizedDataTypeFromFakeQuant(
26     const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
27   if (op.num_bits <= 8) {
28     *out_quantized_data_type = ArrayDataType::kUint8;
29     return true;
30   } else if (op.num_bits <= 16) {
31     *out_quantized_data_type = ArrayDataType::kInt16;
32     return true;
33   } else {
34     *out_quantized_data_type = ArrayDataType::kNone;
35     return false;
36   }
37 }
38 
GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,double * out_min_value,double * out_max_value)39 bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
40                                         double* out_min_value,
41                                         double* out_max_value) {
42   switch (data_type) {
43     case ArrayDataType::kUint8:
44       *out_min_value = 0;
45       *out_max_value = 255;
46       return true;
47     case ArrayDataType::kInt16:
48       *out_min_value = -32768;
49       *out_max_value = 32767;
50       return true;
51     default:
52       return false;
53   }
54 }
55 
GetQuantizedDataType(const Array & array,ArrayDataType default_type)56 ArrayDataType GetQuantizedDataType(const Array& array,
57                                    ArrayDataType default_type) {
58   switch (array.final_data_type) {
59     case ArrayDataType::kInt8:
60     case ArrayDataType::kUint8:
61     case ArrayDataType::kInt16:
62     case ArrayDataType::kUint16:
63     case ArrayDataType::kInt32:
64     case ArrayDataType::kUint32:
65     case ArrayDataType::kInt64:
66     case ArrayDataType::kUint64:
67       return array.final_data_type;
68     case ArrayDataType::kFloat:
69     case ArrayDataType::kNone:
70       return default_type;
71     default:
72       LOG(FATAL) << "Unhandled final quantization type "
73                  << static_cast<int>(array.final_data_type);
74   }
75 }
76 
77 template <ArrayDataType A>
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,QuantizationParams * quantization_params)78 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
79     const Array& array, QuantizationParams* quantization_params) {
80   *quantization_params = ::tflite::ChooseQuantizationParams<DataType<A>>(
81       array.minmax->min, array.minmax->max, array.narrow_range);
82 }
83 
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,ArrayDataType quantized_data_type,QuantizationParams * quantization_params)84 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
85     const Array& array, ArrayDataType quantized_data_type,
86     QuantizationParams* quantization_params) {
87   switch (quantized_data_type) {
88     case ArrayDataType::kInt8:
89       ChooseQuantizationParamsForArrayAndQuantizedDataType<
90           ArrayDataType::kInt8>(array, quantization_params);
91       break;
92     case ArrayDataType::kUint8:
93       ChooseQuantizationParamsForArrayAndQuantizedDataType<
94           ArrayDataType::kUint8>(array, quantization_params);
95       break;
96     case ArrayDataType::kInt16:
97       ChooseQuantizationParamsForArrayAndQuantizedDataType<
98           ArrayDataType::kInt16>(array, quantization_params);
99       break;
100     case ArrayDataType::kUint16:
101       ChooseQuantizationParamsForArrayAndQuantizedDataType<
102           ArrayDataType::kUint16>(array, quantization_params);
103       break;
104     case ArrayDataType::kInt32:
105       ChooseQuantizationParamsForArrayAndQuantizedDataType<
106           ArrayDataType::kInt32>(array, quantization_params);
107       break;
108     case ArrayDataType::kUint32:
109       ChooseQuantizationParamsForArrayAndQuantizedDataType<
110           ArrayDataType::kUint32>(array, quantization_params);
111       break;
112     case ArrayDataType::kInt64:
113       ChooseQuantizationParamsForArrayAndQuantizedDataType<
114           ArrayDataType::kInt64>(array, quantization_params);
115       break;
116     case ArrayDataType::kUint64:
117       ChooseQuantizationParamsForArrayAndQuantizedDataType<
118           ArrayDataType::kUint64>(array, quantization_params);
119       break;
120     case ArrayDataType::kFloat:
121     case ArrayDataType::kComplex64:
122     case ArrayDataType::kNone:
123     default:
124       LOG(FATAL) << "Unhandled final quantization type "
125                  << static_cast<int>(quantized_data_type);
126   }
127 }
128 
129 namespace {
130 
131 template <ArrayDataType A>
QuantizeBuffer(const Array & array,const QuantizationParams & quantization_params)132 std::unique_ptr<GenericBuffer> QuantizeBuffer(
133     const Array& array, const QuantizationParams& quantization_params) {
134   const GenericBuffer& buffer = *array.buffer;
135   const auto inverse_scale = 1. / quantization_params.scale;
136   CHECK(buffer.type == ArrayDataType::kFloat);
137   const auto& float_buffer =
138       static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
139   auto* quantized_buffer = new Buffer<A>;
140   quantized_buffer->data.resize(float_buffer.data.size());
141   for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
142     const float src_val = float_buffer.data[i];
143     double scaled_val;  // Astonishingly, using 'float' degrades accuracy just
144                         // enough to make a few tests fail!
145     if (quantization_params.scale == 0) {
146       CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
147                            << "so all its values should be 0.";
148       scaled_val = quantization_params.zero_point;
149     } else {
150       scaled_val = quantization_params.zero_point + inverse_scale * src_val;
151     }
152     auto integer_val = tflite::SafeCast<DataType<A>>(std::round(scaled_val));
153     // In addition to its effect on the choice of quantization params upstream
154     // of here, narrow_range also means nudge the min quantized value by +1,
155     // so e.g. uint8 values get constrained to [1, 255].
156     if (integer_val == std::numeric_limits<DataType<A>>::min() &&
157         array.narrow_range) {
158       integer_val++;
159     }
160     quantized_buffer->data[i] = integer_val;
161   }
162   return std::unique_ptr<GenericBuffer>(quantized_buffer);
163 }
164 
165 template <ArrayDataType A>
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,const QuantizationParams & quantization_params)166 void QuantizeArray(GraphTransformation* transformation, Model* model,
167                    const std::string& name,
168                    const QuantizationParams& quantization_params) {
169   auto& array = model->GetArray(name);
170   CHECK(array.data_type == ArrayDataType::kFloat);
171   CHECK(!array.quantization_params);
172   array.GetOrCreateQuantizationParams() = quantization_params;
173   if (array.buffer) {
174     array.buffer = QuantizeBuffer<A>(array, quantization_params);
175   }
176   array.data_type = A;
177   array.final_data_type = A;
178   transformation->AddMessageF(
179       "Quantized array %s to %s zero_point=%g, scale=%g", name,
180       ArrayDataTypeName(array.data_type), quantization_params.zero_point,
181       quantization_params.scale);
182 }
183 
184 }  // namespace
185 
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,ArrayDataType quantized_data_type,const QuantizationParams & quantization_params)186 void QuantizeArray(GraphTransformation* transformation, Model* model,
187                    const std::string& name, ArrayDataType quantized_data_type,
188                    const QuantizationParams& quantization_params) {
189   ArrayDataType adjusted_data_type = quantized_data_type;
190   auto& array = model->GetArray(name);
191   if (array.final_data_type == ArrayDataType::kInt16) {
192     adjusted_data_type = array.final_data_type;
193   }
194 
195   switch (adjusted_data_type) {
196     case ArrayDataType::kUint8:
197       return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
198                                                   quantization_params);
199     case ArrayDataType::kInt16:
200       return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
201                                                   quantization_params);
202     case ArrayDataType::kInt32:
203       return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
204                                                   quantization_params);
205     default:
206       LOG(FATAL) << "Unhandled case.";
207   }
208 }
209 
IsArrayQuantizedRangeSubset(GraphTransformation * transformation,const Array & array,double clamp_min,double clamp_max)210 bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
211                                  const Array& array, double clamp_min,
212                                  double clamp_max) {
213   ArrayDataType quantized_data_type =
214       GetQuantizedDataType(array, array.data_type);
215   if (quantized_data_type == ArrayDataType::kNone ||
216       quantized_data_type == ArrayDataType::kFloat) {
217     // The array is not (or never will be) quantized.
218     return false;
219   }
220 
221   QuantizationParams quantization_params;
222   if (!array.quantization_params) {
223     if (!array.minmax) {
224       transformation->AddMessageF("No quantization params and no minmax");
225       return false;
226     } else {
227       // Work around cases where we are asking for this prior to the Quantize
228       // transformation having added the quantization_params.
229       ChooseQuantizationParamsForArrayAndQuantizedDataType(
230           array, quantized_data_type, &quantization_params);
231       transformation->AddMessageF(
232           "No quantization params - inferring from data type %s with minmax "
233           "%g,%g as zero_point=%g, scale=%g",
234           ArrayDataTypeName(quantized_data_type), array.minmax->min,
235           array.minmax->max, quantization_params.zero_point,
236           quantization_params.scale);
237     }
238   } else {
239     quantization_params = array.GetQuantizationParams();
240   }
241 
242   double quantized_min, quantized_max;
243   CHECK(GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
244                                            &quantized_max))
245       << "Type is not quantized";
246 
247   bool has_nontrivial_min_bound = false;
248   bool has_nontrivial_max_bound = false;
249 
250   double lowest_representable_output =
251       (quantized_min - quantization_params.zero_point) *
252       quantization_params.scale;
253   if (lowest_representable_output < clamp_min) {
254     has_nontrivial_min_bound = true;
255     transformation->AddMessageF(
256         "Quantized activation function is not trivial: "
257         "the lowest representable output value %g"
258         " less than the clamp min bound %g.",
259         lowest_representable_output, clamp_min);
260   }
261 
262   double highest_representable_output =
263       (quantized_max - quantization_params.zero_point) *
264       quantization_params.scale;
265   if (highest_representable_output > clamp_max) {
266     has_nontrivial_max_bound = true;
267     transformation->AddMessageF(
268         "Quantized activation function is not trivial: "
269         "the highest representable output value %g"
270         " is greater than the clamp max bound %g.",
271         highest_representable_output, clamp_max);
272   }
273 
274   return !has_nontrivial_min_bound && !has_nontrivial_max_bound;
275 }
276 
277 }  // namespace toco
278