/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

bool SupportsQuantization(const Operator& op) {
  auto type = op.type;
  if (type == OperatorType::kUnsupported) {
    auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
    return unsupported->quantized;
  }
  return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv ||
         type == OperatorType::kFullyConnected ||
         type == OperatorType::kConcatenation ||
         type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
         type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
         type == OperatorType::kMinimum || type == OperatorType::kMaximum ||
         type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
         type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
         type == OperatorType::kResizeBilinear ||
         type == OperatorType::kSplit || type == OperatorType::kSub ||
         type == OperatorType::kSqueeze || type == OperatorType::kPad ||
         type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
         type == OperatorType::kTanh || type == OperatorType::kMul ||
         type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum ||
         type == OperatorType::kSpaceToBatchND ||
         type == OperatorType::kSpaceToDepth ||
         type == OperatorType::kStridedSlice ||
         type == OperatorType::kDepthToSpace ||
         type == OperatorType::kLstmCell || type == OperatorType::kGather ||
         type == OperatorType::kTranspose || type == OperatorType::kMean ||
         type == OperatorType::kEqual || type == OperatorType::kGreater ||
         type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
         type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
         type == OperatorType::kArgMax || type == OperatorType::kRelu ||
         type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
         type == OperatorType::kShape || type == OperatorType::kExpandDims ||
         type == OperatorType::kPack || type == OperatorType::kTopK_V2 ||
         type == OperatorType::kRandomUniform ||
         type == OperatorType::kResizeNearestNeighbor ||
         type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
         type == OperatorType::kReduceMin;
}

// The quantized op allows output arrays of type float using
// the attribute support_output_type_float_in_quantized_op.
bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) {
  auto type = op.type;
  if (type == OperatorType::kUnsupported) {
    auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
    return unsupported->support_output_type_float_in_quantized_op;
  }
  return false;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
  auto& array = model->GetArray(array_name);
  // Normally we should have a MinMax recorded on this Array,
  // so we just use it.
  if (array.minmax != nullptr) {
    return *array.minmax;
  }

  // We don't have a MinMax. That's bad news: we need
  // the graph to provide MinMax info for all arrays in order
  // for inference to reproduce faithfully the same quantization
  // error as the training process had.
  //
  // But we still want to support a fallback for constant arrays,
  // just using the plain min and max computed from array elements.
  // We should hopefully never rely on that in production, as that
  // will not give very good accuracy as that typically won't be
  // exactly what the training process used. But it will be useful
  // to allow easily trying out quantization even if the graph
  // lacks some minmax information.
  if (array.buffer != nullptr) {
    CHECK(array.buffer->type == ArrayDataType::kFloat);
    const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
    // We always want [min, max] to contain 0.
    float min = 0.f;
    float max = 0.f;
    for (const auto& val : data) {
      min = std::min(min, val);
      max = std::max(max, val);
    }
    if (min == 0.f && max == 0.f) {
      // Prevent downstream anger from quantized math that expects min and max
      // to not be equal.
      max = 1.f;
    }
    // No need to warn about accuracy if all array values are equal to either
    // min or max:
    // in that case, quantization is exact, and such arrays are not learned
    // weights arrays for which fake-quantization would make sense, rather
    // they tend to be hardcoded arrays of zeros or ones used in some graphs.
    bool is_quantization_trivially_exact = true;
    for (const auto& val : data) {
      is_quantization_trivially_exact &= (val == min || val == max);
    }
    if (!is_quantization_trivially_exact) {
      LOG(WARNING)
          << "Constant array " << array_name
          << " lacks MinMax information. To make up for that, we will now "
             "compute"
          << " the MinMax from actual array elements. That will result in"
          << " quantization parameters that probably do not match whichever "
             "arithmetic"
          << " was used during training, and thus will probably be a cause of "
             "poor"
          << " inference accuracy.";
    }
    auto& minmax = array.GetOrCreateMinMax();
    minmax.min = min;
    minmax.max = max;
    return minmax;
  }

  LOG(FATAL) << "Array " << array_name
             << " does not have MinMax information, "
                "and is not a constant array. Cannot "
                "proceed with quantization.";
}

struct QuantizationPoints {
  int64 min_value;
  int64 max_value;
  int64 central_value;
};

template <ArrayDataType A>
QuantizationPoints GetQuantizationPoints() {
  QuantizationPoints qp;
  using Integer = DataType<A>;
  qp.min_value = std::numeric_limits<Integer>::min();
  qp.max_value = std::numeric_limits<Integer>::max();
  // eg [-128,127]...
  qp.central_value = (qp.min_value / 2 +        // -128 -> -64.
                      (qp.max_value - 1) / 2 +  // 127 -> 63.
                      1);
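  // For instance, the arithmetic above yields central_value = 128 for kUint8
  // ([0, 255]: 0/2 + 254/2 + 1) and central_value = 0 for kInt16
  // ([-32768, 32767]: -16384 + 16383 + 1); the hardcoded symmetric ranges
  // below (e.g. Tanh) use this central value as their zero point.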
  return qp;
}

QuantizationPoints GetQuantizationPoints(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kUint8:
      return GetQuantizationPoints<ArrayDataType::kUint8>();
    case ArrayDataType::kInt16:
      return GetQuantizationPoints<ArrayDataType::kInt16>();
    case ArrayDataType::kInt32:
      return GetQuantizationPoints<ArrayDataType::kInt32>();
    default:
      LOG(FATAL) << "Unhandled case.";
  }
}

bool ChooseQuantizationForOperatorInput(
    GraphTransformation* transformation, Model* model, const Operator& op,
    std::size_t input_index, ArrayDataType* quantized_data_type,
    QuantizationParams* quantization_params) {
  const auto& input = op.inputs[input_index];
  auto& array = model->GetArray(input);
  if (array.data_type != ArrayDataType::kFloat) {
    return false;
  }

  // Quantization of bias vectors
  bool is_bias_vector = false;
  int activations_input_index;
  int weights_input_index;
  if (op.type == OperatorType::kConv ||
      op.type == OperatorType::kDepthwiseConv ||
      op.type == OperatorType::kFullyConnected) {
    if (input_index == 2) {
      is_bias_vector = true;
      activations_input_index = 0;
      weights_input_index = 1;
    }
  }
  if (op.type == OperatorType::kLstmCell) {
    if (input_index == LstmCellOperator::BIASES_INPUT) {
      is_bias_vector = true;
      activations_input_index = LstmCellOperator::DATA_INPUT;
      weights_input_index = LstmCellOperator::WEIGHTS_INPUT;
    }
  }
  if (is_bias_vector) {
    // Quantization of bias vector.
    // We need both of the mandatory inputs (input activations and weights) to
    // have been already quantized.
    const auto& input_activations =
        model->GetArray(op.inputs[activations_input_index]);
    const auto& input_weights = model->GetArray(op.inputs[weights_input_index]);
    if (!input_activations.quantization_params ||
        !input_weights.quantization_params) {
      transformation->AddMessageF(
          "Input array %s is a bias vector but has no qparams", input);
      return false;
    }
    const auto input_activations_scale =
        input_activations.quantization_params->scale;
    const auto input_weights_scale = input_weights.quantization_params->scale;
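    // The bias is added to the int32 accumulator that holds products of
    // quantized activations and weights, so its scale must be the product of
    // the two input scales (with zero point 0) for the real-valued arithmetic
    // to stay consistent.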
    quantization_params->scale = input_activations_scale * input_weights_scale;
    quantization_params->zero_point = 0;
    *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kInt32);
    transformation->AddMessageF(
        "Input array %s is a bias vector. Choosing quantization params "
        "accordingly.",
        input);
    return true;
  }

  const MinMax& minmax = GetOrComputeMinMax(model, input);

  if (op.type == OperatorType::kLstmCell) {
    if (input_index == LstmCellOperator::PREV_STATE_INPUT) {
      *quantized_data_type = ArrayDataType::kInt16;
      ChooseQuantizationParamsForArrayAndQuantizedDataType(
          array, *quantized_data_type, quantization_params);
      return true;
    }
  }

  *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
  ChooseQuantizationParamsForArrayAndQuantizedDataType(
      array, *quantized_data_type, quantization_params);
  transformation->AddMessageF(
      "For input array %s with min=%g, max=%g, chose to quantize as %s (f=%s) "
      "with zero_point=%d, scale=%g",
      input, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
      ArrayDataTypeName(array.final_data_type), quantization_params->zero_point,
      quantization_params->scale);
  return true;
}

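// Returns whether `real_value` falls exactly on a quantized grid point of the
// given type and quantization params. For example, with zero_point = 0 and
// scale = 1/256, the value 0.5 scales to exactly 128 and is representable in
// kUint8, while 0.3 scales to 76.8 and is not.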
bool IsExactlyRepresentable(double real_value, ArrayDataType data_type,
                            const QuantizationParams& quantization_params) {
  const double scaled_value =
      quantization_params.zero_point + real_value / quantization_params.scale;
  const double fractional_scaled_value =
      scaled_value - std::round(scaled_value);
  if (std::abs(fractional_scaled_value) > 1e-12) {
    return false;
  }
  const double rounded_scaled_value = std::round(scaled_value);
  if (data_type == ArrayDataType::kUint8) {
    if (rounded_scaled_value < 0 || rounded_scaled_value > 255) {
      return false;
    }
  }
  return true;
}

// Quantized data type is preset to the type of the input before this function.
bool ChooseHardcodedQuantizationForOperatorOutput(
    const Operator& op, const Array& array, ArrayDataType* quantized_data_type,
    QuantizationParams* quantization_params) {
  if (op.type == OperatorType::kL2Normalization) {
    // L2Normalization has range: [-1, 1].
    // 0 should be exactly representable, as values will typically be centered
    // around 0, with many values near 0.
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = qp.central_value;
    quantization_params->scale = 1. / (qp.central_value - qp.min_value);
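    // For kUint8 this gives zero_point = 128 and scale = 1/128, i.e. a
    // representable range of [-1, 127/128] with real 0 mapped exactly to 128.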
    CHECK(
        IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
    return true;
  }
  if (op.type == OperatorType::kLogistic || op.type == OperatorType::kSoftmax) {
    // Logistic and Softmax have range: [0, 1].
    //
    // For Logistic, 0.5 should be exactly representable, as implementations
    // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
    // the glueing of the two halves of the graph will only be seamless if we
    // are accurately representing logistic(0) == 0.5.
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = 0;
    quantization_params->scale = 1. / (qp.max_value + 1);
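    // For kUint8 this gives scale = 1/256 with zero_point = 0, so 0.5 maps
    // exactly to 128; note that 1.0 itself is not exactly representable (the
    // top of the representable range is 255/256).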
    CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
                                 *quantization_params));
    return true;
  }
  if (op.type == OperatorType::kLogSoftmax) {
    // LogSoftmax has range: [LogSoftmaxOperator::kOutputRangeMin, 0].
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = qp.max_value;
    quantization_params->scale =
        -LogSoftmaxOperator::kOutputRangeMin / (qp.max_value + 1);
    // While not strictly necessary, it is easier to interpret output data and
    // quantization if the scale is similar to others (such as power of 2).
    CHECK(IsExactlyRepresentable(LogSoftmaxOperator::kOutputRangeMin / 2,
                                 *quantized_data_type, *quantization_params));
    return true;
  }
  if (op.type == OperatorType::kTanh) {
    // Tanh has the range: [-1, 1].
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
    quantization_params->zero_point = qp.central_value;
    quantization_params->scale = 1. / (qp.central_value - qp.min_value);
    // 0 should be exactly representable, as values will typically be centered
    // around 0, with many values near 0.
    CHECK(
        IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
    return true;
  }
  return false;
}

bool ChooseQuantizationForOperatorOutput(
    GraphTransformation* transformation, Model* model, const Operator& op,
    std::size_t output_index, ArrayDataType* quantized_data_type,
    QuantizationParams* quantization_params) {
  const auto& output = op.outputs[output_index];
  auto& array = model->GetArray(output);
  if (array.data_type != ArrayDataType::kFloat) {
    transformation->AddMessageF("Array data type already set to %s, final=%s",
                                ArrayDataTypeName(array.data_type),
                                ArrayDataTypeName(array.final_data_type));
    return false;
  }
  *quantized_data_type = model->GetArray(op.inputs[0]).data_type;
  if (ChooseHardcodedQuantizationForOperatorOutput(
          op, array, quantized_data_type, quantization_params)) {
    transformation->AddMessageF(
        "Output array %s is produced by a %s operator. Choosing fixed "
        "quantization params accordingly.",
        output, OperatorTypeName(op.type));
    return true;
  }
  if ((op.type == OperatorType::kConcatenation &&
       model->flags.change_concat_input_ranges()) ||
      op.type == OperatorType::kDepthToSpace ||
      op.type == OperatorType::kSpaceToDepth ||
      op.type == OperatorType::kReshape || op.type == OperatorType::kSplit ||
      op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 ||
      op.type == OperatorType::kRelu6 || op.type == OperatorType::kPRelu) {
    int data_input_index = 0;
    if (op.type == OperatorType::kSplit) {
      data_input_index = 1;
    }
    // Copying and rearrangement ops should preserve the quantization
    // parameters of the input array.
    const auto& input_array = model->GetArray(op.inputs[data_input_index]);
    const auto& input_quantization_params = input_array.GetQuantizationParams();
    *quantized_data_type =
        GetQuantizedDataType(input_array, ArrayDataType::kUint8);
    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
    quantization_params->zero_point = input_quantization_params.zero_point;
    quantization_params->scale = input_quantization_params.scale;

    transformation->AddMessageF(
        "Output array %s is produced by a %s operator. Copying quantization "
        "params from input array.",
        output, OperatorTypeName(op.type));
    return true;
  }
  const MinMax& minmax = GetOrComputeMinMax(model, output);
  if (op.type == OperatorType::kLstmCell) {
    if (output_index == LstmCellOperator::STATE_OUTPUT ||
        output_index == LstmCellOperator::ACTIV_TEMP) {
      *quantized_data_type = ArrayDataType::kInt16;
      ChooseQuantizationParamsForArrayAndQuantizedDataType(
          array, *quantized_data_type, quantization_params);
      return true;
    }
  }
  *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
  ChooseQuantizationParamsForArrayAndQuantizedDataType(
      array, *quantized_data_type, quantization_params);
  transformation->AddMessageF(
      "For output array %s with min=%g, max=%g"
      ", chose to quantize as %s with zero_point=%d"
      ", scale=%g",
      output, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
      quantization_params->zero_point, quantization_params->scale);

  return true;
}

// Fixes array minmax info to match the quantization parameters.
// This is required when quantization parameters change for an array during
// quantization (such as in ChooseQuantizationForOperatorOutput).
void FixMinMaxPostQuantization(GraphTransformation* transformation,
                               ArrayDataType quantized_data_type,
                               const QuantizationParams& quantization_params,
                               MinMax* minmax) {
  double quantized_min, quantized_max;
  if (!GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
                                          &quantized_max)) {
    // Not quantized - no update required.
    return;
  }

  // Compute new minmax values.
  double min = (quantized_min - quantization_params.zero_point) *
               quantization_params.scale;
  double max = (quantized_max - quantization_params.zero_point) *
               quantization_params.scale;

  // If we are close to the existing minmax values, don't bother changing them.
  // This prevents propagating small floating point precision errors.
  constexpr double kMinMaxThreshold = 1e-5;
  const double width = max - min;
  if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
      std::abs(max - minmax->max) > kMinMaxThreshold * width) {
    transformation->AddMessageF(
        "Adjusting min/max from %g,%g to %g,%g to match quantization params",
        minmax->min, minmax->max, min, max);
    minmax->min = min;
    minmax->max = max;
  }
}

}  // namespace

::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index,
                                   bool* modified) {
  *modified = false;
  // Our general "quantization" graph transformation consists in replacing
  //   QuantizedInputArrays[] ->
  //     DequantizeOperators[] ->
  //       FloatInputArrays[] ->
  //         Operator ->
  //           FloatOutputArray
  // by
  //   QuantizedInputArrays[] ->
  //     Operator ->
  //       QuantizedOutputArray ->
  //         DequantizeOperator ->
  //           FloatOutputArray
  //
  // In other words, this is pushing Dequantize operators to the right of
  // other operators.
  //

  auto& op = *model->operators[op_index];
  if (op.type == OperatorType::kDequantize ||
      op.type == OperatorType::kFakeQuant) {
    return ::tensorflow::Status::OK();
  }

  // Our assumption here is that the input arrays are already quantized -
  // that is typically the case in models operating on an input bitmap
  // image, and MakeInitialDequantizeOp should have already resolved
  // the handling of the input image as an initial Dequantize op.
  //
  // Thus we are building around the assumption that the graph always starts
  // with a quantized input array, and only after some Dequantize op do we have
  // float arrays. The problem of quantizing the graph thus becomes a problem of
  // pushing Dequantize ops to the right of other ops.
  //
  // Let us just guard this assumption by the following assertion:
  for (const auto& input : op.inputs) {
    const auto& input_array = model->GetArray(input);
    if (IsInputArray(*model, input) &&
        input_array.data_type == ArrayDataType::kFloat) {
      CHECK(input_array.quantization_params)
          << "Input array " << input << " is missing quantization_params";
    }
  }
  if (!SupportsQuantization(op)) {
    return tensorflow::errors::InvalidArgument(
        "Unimplemented: this graph contains an operator of type ",
        HelpfulOperatorTypeName(op),
        " for which the quantized form is not yet implemented. Sorry, and "
        "patches welcome (that's a relatively fun patch to write, mostly "
        "providing the actual quantized arithmetic code for this op).");
  }

  for (const auto& input : op.inputs) {
    const auto& array = model->GetArray(input);
    if (array.data_type == ArrayDataType::kFloat) {
      if (!array.minmax && !array.buffer) {
        LOG(WARNING) << "Can't quantize input array " << input
                     << " because it lacks min/max info";
        return ::tensorflow::Status::OK();
      }
      const auto* other_op = GetOpWithOutput(*model, input);
      if (other_op && other_op->type != OperatorType::kDequantize) {
        AddMessageF(
            "Not quantizing %s for now, because its input array %s is not "
            "produced by a Dequantize op, "
            "which means that we should yield and let other ops "
            "get quantized first",
            LogName(op), input);
        return ::tensorflow::Status::OK();
      }
    }
  }

  bool changed = false;

  // Quantize inputs, remove any Dequantize op on the inputs side
  for (std::size_t input_index = 0; input_index < op.inputs.size();
       input_index++) {
    ArrayDataType quantized_data_type;
    QuantizationParams quantization_params;
    if (ChooseQuantizationForOperatorInput(this, model, op, input_index,
                                           &quantized_data_type,
                                           &quantization_params)) {
      const auto& input = op.inputs[input_index];
      if (IsConstantParameterArray(*model, input)) {
        QuantizeArray(this, model, input, quantized_data_type,
                      quantization_params);
        changed = true;
      } else {
        auto dequantize_it = FindOpWithOutput(*model, input);
        if (dequantize_it != model->operators.end()) {
          auto* dequantize_op = dequantize_it->get();
          CHECK(dequantize_op->type == OperatorType::kDequantize);
          op.inputs[input_index] = dequantize_op->inputs[0];
          // Check if the output of that Dequantize op was not used by any
          // other operator. We will then erase that Dequantize op.
          if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
            if (IsDiscardableArray(*model, dequantize_op->outputs[0])) {
              // Usual case: we can just discard the dequantize output.
              model->EraseArray(dequantize_op->outputs[0]);
            } else {
              // The dequantize output is not discardable. Special care needed.
              // If any of the model's output_arrays was pointing to the
              // Dequantize op's output, let it point to the Dequantize op's
              // input instead.
              for (int i = 0; i < model->flags.output_arrays_size(); i++) {
                if (model->flags.output_arrays(i) ==
                    dequantize_op->outputs[0]) {
                  // TODO(b/78013785): never rename output arrays.
                  if (IsInputArray(*model, dequantize_op->inputs[0])) {
                    // The op input is an input array and the output is an
                    // output array and we can't have an array be both. Insert a
                    // copy op to ensure the two arrays stay separate.
                    AddMessageF(
                        "Tried to rename output array %d while removing "
                        "dequant "
                        "op %s but array is also an input; inserting copy %s "
                        "-> %s",
                        i, LogName(*dequantize_op),
                        model->flags.output_arrays(i),
                        dequantize_op->inputs[0]);
                    InsertCopyOperator(model, dequantize_op->inputs[0],
                                       dequantize_op->outputs[0]);
                  } else {
                    // Op output is strictly used as an output array, so we can
                    // just rename the array and directly bypass the op.
                    AddMessageF(
                        "Renaming output array %d after removing dequant op "
                        "%s: "
                        "%s -> %s",
                        i, LogName(*dequantize_op),
                        model->flags.output_arrays(i),
                        dequantize_op->inputs[0]);
                    model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
                    model->EraseArray(dequantize_op->outputs[0]);
                  }
                  break;
                }
              }
            }
            model->operators.erase(dequantize_it);
          }
          changed = true;
        } else {
          // This input array is not produced by a Dequantize op.
          // We have encountered this situation in RNN graphs, whose cyclic
          // nature defeats the basic assumption underlying the quantization
          // algorithm implemented here. For now, when we have seen this
          // happening, the array in question was a RNN state array itself,
          // so let us just implement this case here, and guard that assumption
          // with a CHECK. A more general fix would involve revisiting the
          // design of this whole Quantization transformation.
          bool is_rnn_state_array = false;
          for (const auto& rnn_state : model->flags.rnn_states()) {
            if (rnn_state.state_array() == input) {
              is_rnn_state_array = true;
              break;
            }
          }
          CHECK(is_rnn_state_array);
          QuantizeArray(this, model, input, quantized_data_type,
                        quantization_params);
          changed = true;
        }
      }
    }
  }

  // Quantize outputs, add Dequantize ops as needed on the outputs side
  if (SupportOutputTypeFloatInQuantizedOp(op)) {
    LOG(WARNING)
        << HelpfulOperatorTypeName(op) << " is a quantized op "
        << "but it has a model flag that sets the output arrays to float.";
  } else {
    for (std::size_t output_index = 0; output_index < op.outputs.size();
         output_index++) {
      QuantizationParams quantization_params;
      ArrayDataType quantized_data_type;
      if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
                                              &quantized_data_type,
                                              &quantization_params)) {
        changed = true;
        const auto& output = op.outputs[output_index];
        auto& output_array = model->GetArray(output);

        // Fix up the min/max information on the output array to match the
        // chosen quantization parameters.
        CHECK(output_array.minmax)
            << "Output array named " << output << " lacks minmax";
        auto& output_minmax = output_array.GetMinMax();
        FixMinMaxPostQuantization(this, quantized_data_type,
                                  quantization_params, &output_minmax);

        QuantizeArray(this, model, output, quantized_data_type,
                      quantization_params);

        const auto& dequantized_output =
            AvailableArrayName(*model, output + "_dequantized");
        auto& dequantized_output_array =
            model->GetOrCreateArray(dequantized_output);
        dequantized_output_array.data_type = ArrayDataType::kFloat;
        dequantized_output_array.final_data_type = output_array.data_type;
        auto& dequantized_output_minmax =
            dequantized_output_array.GetOrCreateMinMax();
        dequantized_output_minmax.min = output_minmax.min;
        dequantized_output_minmax.max = output_minmax.max;
        for (const auto& other_op : model->operators) {
          for (auto& other_op_input : other_op->inputs) {
            if (other_op_input == output) {
              other_op_input = dequantized_output;
            }
          }
        }
        auto* dequantize_op = new DequantizeOperator;
        dequantize_op->inputs = {output};
        dequantize_op->outputs = {dequantized_output};
        for (int i = 0; i < model->flags.output_arrays_size(); i++) {
          if (model->flags.output_arrays(i) == output) {
            // TODO(b/78013785): never rename output arrays.
            AddMessageF(
                "Renaming output array %d after inserting dequant op %s: %s -> "
                "%s",
                i, LogName(*dequantize_op), model->flags.output_arrays(i),
                dequantized_output);
            model->flags.set_output_arrays(i, dequantized_output);
          }
        }
        const auto op_it = FindOp(*model, &op);
        model->operators.emplace(op_it + 1, dequantize_op);
      }
    }
  }

  *modified = changed;
  return ::tensorflow::Status::OK();
}

}  // namespace toco