/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

bool SupportsQuantization(const Operator& op) {
  auto type = op.type;
  if (type == OperatorType::kUnsupported) {
    auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
    return unsupported->quantized;
  }
  return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv ||
         type == OperatorType::kFullyConnected ||
         type == OperatorType::kConcatenation ||
         type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
         type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
         type == OperatorType::kMinimum || type == OperatorType::kMaximum ||
         type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
         type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
         type == OperatorType::kResizeBilinear ||
         type == OperatorType::kSplit || type == OperatorType::kSub ||
         type == OperatorType::kSqueeze || type == OperatorType::kPad ||
         type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
         type == OperatorType::kTanh || type == OperatorType::kMul ||
         type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum ||
         type == OperatorType::kSpaceToBatchND ||
         type == OperatorType::kSpaceToDepth ||
         type == OperatorType::kStridedSlice ||
         type == OperatorType::kDepthToSpace ||
         type == OperatorType::kLstmCell || type == OperatorType::kGather ||
         type == OperatorType::kTranspose || type == OperatorType::kMean ||
         type == OperatorType::kEqual || type == OperatorType::kGreater ||
         type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
         type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
         type == OperatorType::kArgMax || type == OperatorType::kRelu ||
         type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
         type == OperatorType::kShape || type == OperatorType::kExpandDims ||
         type == OperatorType::kPack || type == OperatorType::kTopK_V2 ||
         type == OperatorType::kRandomUniform ||
         type == OperatorType::kResizeNearestNeighbor ||
         type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
         type == OperatorType::kReduceMin;
}

// The quantized op allows output arrays of type float using
// the attribute support_output_type_float_in_quantized_op.
bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) {
  auto type = op.type;
  if (type == OperatorType::kUnsupported) {
    auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
    return unsupported->support_output_type_float_in_quantized_op;
  }
  return false;
}

const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
  auto& array = model->GetArray(array_name);
  // Normally we should have a MinMax recorded on this Array,
  // so we just use it.
  if (array.minmax != nullptr) {
    return *array.minmax;
  }

  // We don't have a MinMax. That's bad news: we need
  // the graph to provide MinMax info for all arrays in order
  // for inference to reproduce faithfully the same quantization
  // error as the training process had.
  //
  // But we still want to support a fallback for constant arrays,
  // just using the plain min and max computed from array elements.
  // We should hopefully never rely on that in production, as it will
  // not give very good accuracy: the computed range typically won't be
  // exactly what the training process used. But it is useful for easily
  // trying out quantization even if the graph lacks some minmax
  // information.
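  //
  // For example, a constant buffer {0.25f, -1.5f, 3.0f} falls back to the
  // MinMax [-1.5, 3.0] (the range is always widened to contain 0), and an
  // all-zero buffer falls back to [0, 1] so that min != max downstream.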
  if (array.buffer != nullptr) {
    CHECK(array.buffer->type == ArrayDataType::kFloat);
    const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
    // We always want [min, max] to contain 0.
    float min = 0.f;
    float max = 0.f;
    for (const auto& val : data) {
      min = std::min(min, val);
      max = std::max(max, val);
    }
    if (min == 0.f && max == 0.f) {
      // Prevent downstream anger from quantized math that expects min and max
      // to not be equal.
      max = 1.f;
    }
    // No need to warn about accuracy if all array values are equal to either
    // min or max:
    // in that case, quantization is exact, and such arrays are not learned
    // weights arrays for which fake-quantization would make sense, rather
    // they tend to be hardcoded arrays of zeros or ones used in some graphs.
    bool is_quantization_trivially_exact = true;
    for (const auto& val : data) {
      is_quantization_trivially_exact &= (val == min || val == max);
    }
    if (!is_quantization_trivially_exact) {
      LOG(WARNING)
          << "Constant array " << array_name
          << " lacks MinMax information. To make up for that, we will now "
             "compute the MinMax from actual array elements. That will "
             "result in quantization parameters that probably do not match "
             "whichever arithmetic was used during training, and thus will "
             "probably be a cause of poor inference accuracy.";
    }
    auto& minmax = array.GetOrCreateMinMax();
    minmax.min = min;
    minmax.max = max;
    return minmax;
  }

  LOG(FATAL) << "Array " << array_name
             << " does not have MinMax information, "
                "and is not a constant array. Cannot "
                "proceed with quantization.";
}

struct QuantizationPoints {
  int64 min_value;
  int64 max_value;
  int64 central_value;
};

template <ArrayDataType A>
QuantizationPoints GetQuantizationPoints() {
  QuantizationPoints qp;
  using Integer = DataType<A>;
  qp.min_value = std::numeric_limits<Integer>::min();
  qp.max_value = std::numeric_limits<Integer>::max();
  // eg [-128,127]...
  qp.central_value = (qp.min_value / 2 +        // -128 -> -64.
                      (qp.max_value - 1) / 2 +  // 127 -> 63.
                      1);
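  // Concretely: kUint8 [0, 255] -> central value 128; the symmetric signed
  // ranges kInt16 and kInt32 -> central value 0.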
  return qp;
}

QuantizationPoints GetQuantizationPoints(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kUint8:
      return GetQuantizationPoints<ArrayDataType::kUint8>();
    case ArrayDataType::kInt16:
      return GetQuantizationPoints<ArrayDataType::kInt16>();
    case ArrayDataType::kInt32:
      return GetQuantizationPoints<ArrayDataType::kInt32>();
    default:
      LOG(FATAL) << "Unhandled case.";
  }
}

bool ChooseQuantizationForOperatorInput(
    GraphTransformation* transformation, Model* model, const Operator& op,
    std::size_t input_index, ArrayDataType* quantized_data_type,
    QuantizationParams* quantization_params) {
  const auto& input = op.inputs[input_index];
  auto& array = model->GetArray(input);
  if (array.data_type != ArrayDataType::kFloat) {
    return false;
  }

  // Quantization of bias vectors
  bool is_bias_vector = false;
  int activations_input_index;
  int weights_input_index;
  if (op.type == OperatorType::kConv ||
      op.type == OperatorType::kDepthwiseConv ||
      op.type == OperatorType::kFullyConnected) {
    if (input_index == 2) {
      is_bias_vector = true;
      activations_input_index = 0;
      weights_input_index = 1;
    }
  }
  if (op.type == OperatorType::kLstmCell) {
    if (input_index == LstmCellOperator::BIASES_INPUT) {
      is_bias_vector = true;
      activations_input_index = LstmCellOperator::DATA_INPUT;
      weights_input_index = LstmCellOperator::WEIGHTS_INPUT;
    }
  }
  if (is_bias_vector) {
    // Quantization of bias vector.
    // We need both of the mandatory inputs (input activations and weights) to
    // have been already quantized.
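    // The bias is added to the int32 accumulator of (quantized input) *
    // (quantized weight) products, so it is quantized as int32 with
    // zero_point = 0 and scale = input_scale * weight_scale.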
    const auto& input_activations =
        model->GetArray(op.inputs[activations_input_index]);
    const auto& input_weights = model->GetArray(op.inputs[weights_input_index]);
    if (!input_activations.quantization_params ||
        !input_weights.quantization_params) {
      transformation->AddMessageF(
          "Input array %s is a bias vector but has no qparams", input);
      return false;
    }
    const auto input_activations_scale =
        input_activations.quantization_params->scale;
    const auto input_weights_scale = input_weights.quantization_params->scale;
    quantization_params->scale = input_activations_scale * input_weights_scale;
    quantization_params->zero_point = 0;
    *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kInt32);
    transformation->AddMessageF(
        "Input array %s is a bias vector. Choosing quantization params "
        "accordingly.",
        input);
    return true;
  }

  const MinMax& minmax = GetOrComputeMinMax(model, input);

  if (op.type == OperatorType::kLstmCell) {
    if (input_index == LstmCellOperator::PREV_STATE_INPUT) {
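      // Use 16-bit quantization for the LSTM internal state, presumably to
      // retain more precision across time steps; the matching state output
      // is handled the same way in ChooseQuantizationForOperatorOutput.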
      *quantized_data_type = ArrayDataType::kInt16;
      ChooseQuantizationParamsForArrayAndQuantizedDataType(
          array, *quantized_data_type, quantization_params);
      return true;
    }
  }

  *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
  ChooseQuantizationParamsForArrayAndQuantizedDataType(
      array, *quantized_data_type, quantization_params);
  transformation->AddMessageF(
      "For input array %s with min=%g, max=%g, chose to quantize as %s (f=%s) "
      "with zero_point=%d, scale=%g",
      input, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
      ArrayDataTypeName(array.final_data_type), quantization_params->zero_point,
      quantization_params->scale);
  return true;
}

bool IsExactlyRepresentable(double real_value, ArrayDataType data_type,
                            const QuantizationParams& quantization_params) {
  const double scaled_value =
      quantization_params.zero_point + real_value / quantization_params.scale;
  const double fractional_scaled_value =
      scaled_value - std::round(scaled_value);
  if (std::abs(fractional_scaled_value) > 1e-12) {
    return false;
  }
  const double rounded_scaled_value = std::round(scaled_value);
  if (data_type == ArrayDataType::kUint8) {
    if (rounded_scaled_value < 0 || rounded_scaled_value > 255) {
      return false;
    }
  }
  return true;
}

// Quantized data type is preset to the type of the input before this function.
bool ChooseHardcodedQuantizationForOperatorOutput(
    const Operator& op, const Array& array, ArrayDataType* quantized_data_type,
    QuantizationParams* quantization_params) {
  if (op.type == OperatorType::kL2Normalization) {
    // L2Normalization has range: [-1, 1].
    // 0 should be exactly representable, as values will typically be centered
    // around 0, with many values near 0.
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = qp.central_value;
    quantization_params->scale = 1. / (qp.central_value - qp.min_value);
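    // With uint8, for example, this is zero_point = 128 and scale = 1/128,
    // i.e. a representable range of [-1, 127/128] with 0 mapping to 128.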
    CHECK(
        IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
    return true;
  }
  if (op.type == OperatorType::kLogistic || op.type == OperatorType::kSoftmax) {
    // Logistic and Softmax have range: [0, 1].
    //
    // For Logistic, 0.5 should be exactly representable, as implementations
    // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
    // the gluing of the two halves of the graph will only be seamless if we
    // are accurately representing logistic(0) == 0.5.
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = 0;
    quantization_params->scale = 1. / (qp.max_value + 1);
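    // With uint8, for example, this is zero_point = 0 and scale = 1/256, so
    // 0.5 maps exactly to the quantized value 128.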
    CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
                                 *quantization_params));
    return true;
  }
  if (op.type == OperatorType::kLogSoftmax) {
    // LogSoftmax has range: [LogSoftmaxOperator::kOutputRangeMin, 0].
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = qp.max_value;
    quantization_params->scale =
        -LogSoftmaxOperator::kOutputRangeMin / (qp.max_value + 1);
    // While not strictly necessary, it is easier to interpret output data and
    // quantization if the scale is similar to others (such as power of 2).
    CHECK(IsExactlyRepresentable(LogSoftmaxOperator::kOutputRangeMin / 2,
                                 *quantized_data_type, *quantization_params));
    return true;
  }
  if (op.type == OperatorType::kTanh) {
    // Tanh has the range: [-1, 1].
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = qp.central_value;
    quantization_params->scale = 1. / (qp.central_value - qp.min_value);
    // 0 should be exactly representable, as values will typically be centered
    // around 0, with many values near 0.
    CHECK(
        IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
    return true;
  }
  return false;
}

bool ChooseQuantizationForOperatorOutput(
    GraphTransformation* transformation, Model* model, const Operator& op,
    std::size_t output_index, ArrayDataType* quantized_data_type,
    QuantizationParams* quantization_params) {
  const auto& output = op.outputs[output_index];
  auto& array = model->GetArray(output);
  if (array.data_type != ArrayDataType::kFloat) {
    transformation->AddMessageF("Array data type already set to %s, final=%s",
                                ArrayDataTypeName(array.data_type),
                                ArrayDataTypeName(array.final_data_type));
    return false;
  }
  *quantized_data_type = model->GetArray(op.inputs[0]).data_type;
  if (ChooseHardcodedQuantizationForOperatorOutput(
          op, array, quantized_data_type, quantization_params)) {
    transformation->AddMessageF(
        "Output array %s is produced by a %s operator. Choosing fixed "
        "quantization params accordingly.",
        output, OperatorTypeName(op.type));
    return true;
  }
  if ((op.type == OperatorType::kConcatenation &&
       model->flags.change_concat_input_ranges()) ||
      op.type == OperatorType::kDepthToSpace ||
      op.type == OperatorType::kSpaceToDepth ||
      op.type == OperatorType::kReshape || op.type == OperatorType::kSplit ||
      op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 ||
      op.type == OperatorType::kRelu6 || op.type == OperatorType::kPRelu) {
    int data_input_index = 0;
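    // For Split, input 0 is the split axis, so the data tensor is input 1.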
    if (op.type == OperatorType::kSplit) {
      data_input_index = 1;
    }
    // Copying and rearrangement ops should preserve the quantization
    // parameters of the input array.
    const auto& input_array = model->GetArray(op.inputs[data_input_index]);
    const auto& input_quantization_params = input_array.GetQuantizationParams();
    *quantized_data_type =
        GetQuantizedDataType(input_array, ArrayDataType::kUint8);
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    quantization_params->zero_point = input_quantization_params.zero_point;
    quantization_params->scale = input_quantization_params.scale;

    transformation->AddMessageF(
        "Output array %s is produced by a %s operator. Copying quantization "
        "params from input array.",
        output, OperatorTypeName(op.type));
    return true;
  }
  const MinMax& minmax = GetOrComputeMinMax(model, output);
  if (op.type == OperatorType::kLstmCell) {
    if (output_index == LstmCellOperator::STATE_OUTPUT ||
        output_index == LstmCellOperator::ACTIV_TEMP) {
      *quantized_data_type = ArrayDataType::kInt16;
      ChooseQuantizationParamsForArrayAndQuantizedDataType(
          array, *quantized_data_type, quantization_params);
      return true;
    }
  }
  *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
  ChooseQuantizationParamsForArrayAndQuantizedDataType(
      array, *quantized_data_type, quantization_params);
  transformation->AddMessageF(
      "For output array %s with min=%g, max=%g"
      ", chose to quantize as %s with zero_point=%d"
      ", scale=%g",
      output, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
      quantization_params->zero_point, quantization_params->scale);

  return true;
}

// Fixes array minmax info to match the quantization parameters.
// This is required when quantization parameters change for an array during
// quantization (such as in ChooseQuantizationForOperatorOutput).
void FixMinMaxPostQuantization(GraphTransformation* transformation,
                               ArrayDataType quantized_data_type,
                               const QuantizationParams& quantization_params,
                               MinMax* minmax) {
  double quantized_min, quantized_max;
  if (!GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
                                          &quantized_max)) {
    // Not quantized - no update required.
    return;
  }

  // Compute new minmax values.
  double min = (quantized_min - quantization_params.zero_point) *
               quantization_params.scale;
  double max = (quantized_max - quantization_params.zero_point) *
               quantization_params.scale;
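  // For example, uint8 with zero_point = 128 and scale = 1/128 recovers the
  // minmax [-1.0, 127/128].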

  // If we are close to the existing minmax values don't bother changing them.
  // This prevents propagating small floating point precision errors.
  constexpr double kMinMaxThreshold = 1e-5;
  const double width = max - min;
  if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
      std::abs(max - minmax->max) > kMinMaxThreshold * width) {
    transformation->AddMessageF(
        "Adjusting min/max from %g,%g to %g,%g to match quantization params",
        minmax->min, minmax->max, min, max);
    minmax->min = min;
    minmax->max = max;
  }
}

}  // namespace

::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index,
                                   bool* modified) {
  *modified = false;
  // Our general "quantization" graph transformation consists in replacing
  //   QuantizedInputArrays[] ->
  //     DequantizeOperators[] ->
  //       FloatInputArrays[] ->
  //         Operator ->
  //           FloatOutputArray
  // by
  //   QuantizedInputArrays[] ->
  //     Operator ->
  //       QuantizedOutputArray ->
  //         DequantizeOperator ->
  //           FloatOutputArray
  //
  // In other words, this is pushing Dequantize operators to the right of
  // other operators.
  //

  auto& op = *model->operators[op_index];
  if (op.type == OperatorType::kDequantize ||
      op.type == OperatorType::kFakeQuant) {
    return ::tensorflow::Status::OK();
  }

  // Our assumption here is that the input arrays are already quantized -
  // that is typically the case in models operating on an input bitmap
  // image, and MakeInitialDequantizeOp should have already resolved
  // the handling of the input image as an initial Dequantize op.
  //
  // Thus we are building around the assumption that the graph always starts
  // with a quantized input array, and only after some Dequantize op do we have
  // float arrays. The problem of quantizing the graph thus becomes a problem
  // of pushing Dequantize ops to the right of other ops.
  //
  // Let us just guard this assumption by the following assertion:
  for (const auto& input : op.inputs) {
    const auto& input_array = model->GetArray(input);
    if (IsInputArray(*model, input) &&
        input_array.data_type == ArrayDataType::kFloat) {
      CHECK(input_array.quantization_params)
          << "Input array " << input << " is missing quantization_params";
    }
  }
  if (!SupportsQuantization(op)) {
    return tensorflow::errors::InvalidArgument(
        "Unimplemented: this graph contains an operator of type ",
        HelpfulOperatorTypeName(op),
        " for which the quantized form is not yet implemented. Sorry, and "
        "patches welcome (that's a relatively fun patch to write, mostly "
        "providing the actual quantized arithmetic code for this op).");
  }

  for (const auto& input : op.inputs) {
    const auto& array = model->GetArray(input);
    if (array.data_type == ArrayDataType::kFloat) {
      if (!array.minmax && !array.buffer) {
        LOG(WARNING) << "Can't quantize input array " << input
                     << " because it lacks min/max info";
        return ::tensorflow::Status::OK();
      }
      const auto* other_op = GetOpWithOutput(*model, input);
      if (other_op && other_op->type != OperatorType::kDequantize) {
        AddMessageF(
            "Not quantizing %s for now, because its input array %s is not "
            "produced by a Dequantize op, "
            "which means that we should yield and let other ops "
            "get quantized first",
            LogName(op), input);
        return ::tensorflow::Status::OK();
      }
    }
  }

  bool changed = false;

  // Quantize inputs, remove any Dequantize op on the inputs side
  for (std::size_t input_index = 0; input_index < op.inputs.size();
       input_index++) {
    ArrayDataType quantized_data_type;
    QuantizationParams quantization_params;
    if (ChooseQuantizationForOperatorInput(this, model, op, input_index,
                                           &quantized_data_type,
                                           &quantization_params)) {
      const auto& input = op.inputs[input_index];
      if (IsConstantParameterArray(*model, input)) {
        QuantizeArray(this, model, input, quantized_data_type,
                      quantization_params);
        changed = true;
      } else {
        auto dequantize_it = FindOpWithOutput(*model, input);
        if (dequantize_it != model->operators.end()) {
          auto* dequantize_op = dequantize_it->get();
          CHECK(dequantize_op->type == OperatorType::kDequantize);
          op.inputs[input_index] = dequantize_op->inputs[0];
          // If the output of that Dequantize op is not used by any other
          // operator, we can erase that Dequantize op altogether.
          if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
            if (IsDiscardableArray(*model, dequantize_op->outputs[0])) {
              // Usual case: we can just discard the dequantize output.
              model->EraseArray(dequantize_op->outputs[0]);
            } else {
              // The dequantize output is not discardable. Special care needed.
              // If any of the model's output_arrays was pointing to the
              // Dequantize op's output, let it point to the Dequantize op's
              // input instead.
              for (int i = 0; i < model->flags.output_arrays_size(); i++) {
                if (model->flags.output_arrays(i) ==
                    dequantize_op->outputs[0]) {
                  // TODO(b/78013785): never rename output arrays.
                  if (IsInputArray(*model, dequantize_op->inputs[0])) {
                    // The op input is an input array and the output is an
                    // output array and we can't have an array be both. Insert
                    // a copy op to ensure the two arrays stay separate.
                    AddMessageF(
                        "Tried to rename output array %d while removing "
                        "dequant op %s but array is also an input; inserting "
                        "copy %s -> %s",
                        i, LogName(*dequantize_op),
                        model->flags.output_arrays(i),
                        dequantize_op->inputs[0]);
                    InsertCopyOperator(model, dequantize_op->inputs[0],
                                       dequantize_op->outputs[0]);
                  } else {
                    // Op output is strictly used as an output array, so we can
                    // just rename the array and directly bypass the op.
                    AddMessageF(
                        "Renaming output array %d after removing dequant op "
                        "%s: %s -> %s",
                        i, LogName(*dequantize_op),
                        model->flags.output_arrays(i),
                        dequantize_op->inputs[0]);
                    model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
                    model->EraseArray(dequantize_op->outputs[0]);
                  }
                  break;
                }
              }
            }
            model->operators.erase(dequantize_it);
          }
          changed = true;
        } else {
          // This input array is not produced by a Dequantize op.
          // We have encountered this situation in RNN graphs, whose cyclic
          // nature defeats the basic assumption underlying the quantization
          // algorithm implemented here. For now, when we have seen this
          // happening, the array in question was a RNN state array itself,
          // so let us just implement this case here, and guard that assumption
          // with a CHECK. A more general fix would involve revisiting the
          // design of this whole Quantization transformation.
          bool is_rnn_state_array = false;
          for (const auto& rnn_state : model->flags.rnn_states()) {
            if (rnn_state.state_array() == input) {
              is_rnn_state_array = true;
              break;
            }
          }
          CHECK(is_rnn_state_array);
          QuantizeArray(this, model, input, quantized_data_type,
                        quantization_params);
          changed = true;
        }
      }
    }
  }

  // Quantize outputs, add Dequantize ops as needed on the outputs side
  if (SupportOutputTypeFloatInQuantizedOp(op)) {
    LOG(WARNING)
        << HelpfulOperatorTypeName(op) << " is a quantized op "
        << "but it has a model flag that sets the output arrays to float.";
  } else {
    for (std::size_t output_index = 0; output_index < op.outputs.size();
         output_index++) {
      QuantizationParams quantization_params;
      ArrayDataType quantized_data_type;
      if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
                                              &quantized_data_type,
                                              &quantization_params)) {
        changed = true;
        const auto& output = op.outputs[output_index];
        auto& output_array = model->GetArray(output);

        // Fix up the min/max information on the output array to match the
        // chosen quantization parameters.
        CHECK(output_array.minmax)
            << "Output array named " << output << " lacks minmax";
        auto& output_minmax = output_array.GetMinMax();
        FixMinMaxPostQuantization(this, quantized_data_type,
                                  quantization_params, &output_minmax);

        QuantizeArray(this, model, output, quantized_data_type,
                      quantization_params);

        const auto& dequantized_output =
            AvailableArrayName(*model, output + "_dequantized");
        auto& dequantized_output_array =
            model->GetOrCreateArray(dequantized_output);
        dequantized_output_array.data_type = ArrayDataType::kFloat;
        dequantized_output_array.final_data_type = output_array.data_type;
        auto& dequantized_output_minmax =
            dequantized_output_array.GetOrCreateMinMax();
        dequantized_output_minmax.min = output_minmax.min;
        dequantized_output_minmax.max = output_minmax.max;
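        // Rewire every consumer of the original output to read the
        // dequantized array instead; when such a consumer later gets
        // quantized itself, the input-side logic above removes the
        // Dequantize op again.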
        for (const auto& other_op : model->operators) {
          for (auto& other_op_input : other_op->inputs) {
            if (other_op_input == output) {
              other_op_input = dequantized_output;
            }
          }
        }
        auto* dequantize_op = new DequantizeOperator;
        dequantize_op->inputs = {output};
        dequantize_op->outputs = {dequantized_output};
        for (int i = 0; i < model->flags.output_arrays_size(); i++) {
          if (model->flags.output_arrays(i) == output) {
            // TODO(b/78013785): never rename output arrays.
            AddMessageF(
                "Renaming output array %d after inserting dequant op %s: %s "
                "-> %s",
                i, LogName(*dequantize_op), model->flags.output_arrays(i),
                dequantized_output);
            model->flags.set_output_arrays(i, dequantized_output);
          }
        }
        const auto op_it = FindOp(*model, &op);
        model->operators.emplace(op_it + 1, dequantize_op);
      }
    }
  }

  *modified = changed;
  return ::tensorflow::Status::OK();
}

}  // namespace toco