/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>   // for std::abs
#include <limits>  // for std::numeric_limits
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

namespace {

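// A Conv operator may carry a second output holding its im2col array. When
// that array lacks MinMax, give it the same MinMax as the operator's input
// activations, since im2col only rearranges input values.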
bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) {
  if (op->outputs.size() != 2) {
    return false;
  }
  auto& im2col_array = model->GetArray(op->outputs[1]);
  if (im2col_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!im2col_array.minmax);
  auto& im2col_minmax = im2col_array.GetOrCreateMinMax();
  im2col_minmax.min = input_minmax.min;
  im2col_minmax.max = input_minmax.max;
  return true;
}

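// The output of L2Normalization always lies within [-1, 1]; either bound is
// tightened to 0 when the input range shows the input is entirely one-signed.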
bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax.min >= 0. ? 0. : -1.;
  output_minmax.max = input_minmax.max <= 0. ? 0. : 1.;
  return true;
}

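// Deduces an input's MinMax from the output's MinMax. Used for ops such as
// ReLU whose input range can be inferred from the range after the op.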
bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
  auto& input = model->GetArray(op->inputs[0]);
  if (input.minmax) {
    return false;
  }
  const auto& output = model->GetArray(op->outputs[0]);
  if (output.minmax) {
    input.GetOrCreateMinMax() = *output.minmax;
    return true;
  }
  return false;
}

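// Forces all inputs and the output of a Concatenation operator onto a single
// shared MinMax (their union), so that quantized concatenation can be
// implemented as a pure byte-copy.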
bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
  // Do not early return if the output already has min/max:
  // we may still need to adjust the inputs min/max.
  bool has_minmax = false;
  double overall_min = std::numeric_limits<double>::infinity();
  double overall_max = -std::numeric_limits<double>::infinity();
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (input_array.minmax) {
      has_minmax = true;
      overall_min = std::min(overall_min, input_array.minmax->min);
      overall_max = std::max(overall_max, input_array.minmax->max);
    }
  }
  auto& output = model->GetArray(op->outputs[0]);
  if (output.minmax) {
    has_minmax = true;
    overall_min = std::min(overall_min, output.minmax->min);
    overall_max = std::max(overall_max, output.minmax->max);
  }
  if (!has_minmax) {
    return false;
  }
  MinMax overall_minmax;
  overall_minmax.min = overall_min;
  overall_minmax.max = overall_max;
  bool changed = false;
  if (model->flags.change_concat_input_ranges()) {
    for (const auto& input : op->inputs) {
      auto& array = model->GetArray(input);
      if (!array.minmax) {
        changed = true;
      } else if (!(overall_minmax == array.GetMinMax())) {
        changed = true;
        LOG(WARNING)
            << "Tweaking the MinMax of array " << input << ", which is "
            << "an input to " << LogName(*op) << ", because we want all inputs "
            << "and outputs of a Concatenation operator to have the same "
            << "MinMax so that it can be implemented as a pure byte-copy, no "
            << "arithmetic.";
      }
      array.GetOrCreateMinMax() = overall_minmax;
    }
  }
  if (!output.minmax) {
    changed = true;
  } else if (!(overall_minmax == output.GetMinMax())) {
    if (model->flags.change_concat_input_ranges()) {
      changed = true;
      LOG(WARNING)
          << "Tweaking the MinMax of the output array of " << LogName(*op)
          << ", because we want all inputs "
          << "and outputs of a Concatenation operator to have the same MinMax "
          << "so that it can be implemented as a pure byte-copy, no "
          << "arithmetic.";
    } else {
      return false;
    }
  }
  output.GetOrCreateMinMax() = overall_minmax;

  return changed;
}

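// Propagates the MinMax of a Split operator's data input to all its outputs.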
bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
  // Data is in the second input (the first input is the split axis).
  auto& input_array = model->GetArray(op->inputs[1]);
  if (!input_array.minmax) {
    return false;
  }
  bool changed = false;
  for (const auto& output : op->outputs) {
    auto& array = model->GetArray(output);
    if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) {
      changed = true;
      array.GetOrCreateMinMax() = *input_array.minmax;
    }
  }
  return changed;
}

// The output of average or max pooling is within the same range as its
// input; the range is widened to include 0, which a quantization range must
// always contain.
bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = std::min(input_minmax.min, 0.);
  output_minmax.max = std::max(input_minmax.max, 0.);
  return true;
}

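// Copies the first input's MinMax unchanged onto the output. Suitable for
// ops that rearrange or select values without changing their range
// (resize, slice, transpose, gather, etc.).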
bool HardcodeMinMaxFromFirstInput(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax.min;
  output_minmax.max = input_minmax.max;
  return true;
}

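// For Select, the two data inputs (inputs 1 and 2) must share one MinMax,
// which the output then inherits. A constant data input missing MinMax is
// first given the other input's range.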
bool HardcodeMinMaxForSelect(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }

  auto& input_array_1 = model->GetArray(op->inputs[1]);
  auto& input_array_2 = model->GetArray(op->inputs[2]);

  if (!input_array_1.minmax && !input_array_2.minmax) {
    return false;
  }

  // Propagate up if one input is quantized and the other is constant.
  if (!input_array_1.minmax &&
      IsConstantParameterArray(*model, op->inputs[1])) {
    auto& minmax_1 = input_array_1.GetOrCreateMinMax();
    const auto& minmax_2 = input_array_2.GetMinMax();
    minmax_1.min = minmax_2.min;
    minmax_1.max = minmax_2.max;
  }

  if (!input_array_2.minmax &&
      IsConstantParameterArray(*model, op->inputs[2])) {
    auto& minmax_2 = input_array_2.GetOrCreateMinMax();
    const auto& minmax_1 = input_array_1.GetMinMax();
    minmax_2.min = minmax_1.min;
    minmax_2.max = minmax_1.max;
  }

  if (!input_array_1.minmax || !input_array_2.minmax) {
    return false;
  }

  const auto& input_minmax_1 = input_array_1.GetMinMax();
  const auto& input_minmax_2 = input_array_2.GetMinMax();

  CHECK_EQ(input_minmax_1.min, input_minmax_2.min);
  CHECK_EQ(input_minmax_1.max, input_minmax_2.max);
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax_1.min;
  output_minmax.max = input_minmax_1.max;
  return true;
}

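// Hardcodes the single output of `op` to the fixed range [min, max], once
// the input range is known. Used for ops with a fixed output range, such as
// Logistic, Softmax and Tanh.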
bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
                             double max) {
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = min;
  output_minmax.max = max;
  return true;
}

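// Considers two MinMax ranges equal if they agree within a relative
// tolerance of 1e-6 of the smaller range's magnitude.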
bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
  const double magnitude =
      std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
  const double tolerated = 1e-6 * magnitude;
  return std::abs(minmax1.min - minmax2.min) <= tolerated &&
         std::abs(minmax1.max - minmax2.max) <= tolerated;
}

// Propagates MinMax from any of the listed arrays to all the others.
// If multiple of these arrays have MinMax, they are required
// to agree with each other.
bool PropagateMinMaxAmongArrays(Model* model,
                                const std::vector<std::string>& array_names) {
  std::string reference_array_name;
  MinMax* reference_minmax = nullptr;
  for (const std::string& array_name : array_names) {
    if (model->GetArray(array_name).minmax) {
      reference_array_name = array_name;
      reference_minmax = model->GetArray(array_name).minmax.get();
      break;
    }
  }
  // No MinMax info is available to propagate.
  if (!reference_minmax) {
    return false;
  }
  bool changed = false;
  for (const std::string& array_name : array_names) {
    auto& array = model->GetArray(array_name);
    if (array.minmax) {
      CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
          << "Both the following arrays have minmax, and they disagree: "
          << reference_array_name << " (" << reference_minmax->min << ","
          << reference_minmax->max << ") and " << array_name << " ("
          << array.minmax->min << "," << array.minmax->max
          << "). Expected that either only one of them would have minmax, "
             "or at least that they would agree.";
    } else {
      array.GetOrCreateMinMax() = *reference_minmax;
      changed = true;
    }
  }
  return changed;
}

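// Reshape does not change any values, so its input and output must share the
// same MinMax; whichever side has it is propagated to the other.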
bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
  Array& input = model->GetArray(op->inputs[0]);
  Array& output = model->GetArray(op->outputs[0]);

  // If the input and output minmax both exist or both do not exist,
  // do nothing.
  if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) {
    return false;
  }

  // Otherwise propagate the info between the input and output arrays.
  return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]});
}

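// Hardcodes the quantization ranges expected by the quantized LstmCell
// runtime: activation-like arrays get [-1, 127/128], and the internal
// fully-connected output gets [-8, 8 * 32767/32768].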
bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
  CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);

  bool changed = false;
  changed |= PropagateMinMaxAmongArrays(
      model, {op->inputs[LstmCellOperator::PREV_STATE_INPUT],
              op->outputs[LstmCellOperator::STATE_OUTPUT]});

  auto& input_activations =
      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
  if (!input_activations.minmax) {
    auto& minmax = input_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& prev_output_activations =
      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
  if (!prev_output_activations.minmax) {
    auto& minmax = prev_output_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& output_concat_temp =
      model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]);
  if (!output_concat_temp.minmax) {
    auto& minmax = output_concat_temp.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& output_activations =
      model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]);
  if (!output_activations.minmax) {
    auto& minmax = output_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  // (This comment should morph into proper documentation for
  // quantization of LSTM models. It isn't just a local implementation
  // detail: the training code for LSTM models needs to be adjusted to it.)
  //
  // Finally, output_activations_temp holds the output of the fully-connected
  // node inside the LSTM cell. For it, we hardcode a minmax of [-8, 8].
  // The rationale for that is given in a lengthy comment on the LstmCell
  // quantized runtime implementation in reference_ops.h.
  auto& output_activations_temp =
      model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]);
  if (!output_activations_temp.minmax) {
    auto& minmax = output_activations_temp.GetOrCreateMinMax();
    minmax.min = -8;
    minmax.max = 8 * 32767. / 32768.;
    changed = true;
  }

  return changed;
}

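// Pack merely stacks its inputs, so when all inputs share the same MinMax,
// the output inherits it; if any input lacks MinMax or the inputs disagree,
// nothing is hardcoded.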
bool HardcodeMinMaxForPack(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }

  // If all tensors being packed have the same min/max range, hardcode min/max
  // for the output.
  const auto& first_input_array = model->GetArray(op->inputs[0]);
  if (!first_input_array.minmax) {
    return false;
  }
  const auto& first_input_minmax = first_input_array.GetMinMax();

  for (size_t i = 1; i < op->inputs.size(); i++) {
    const auto& input_array = model->GetArray(op->inputs[i]);
    if (!input_array.minmax) {
      return false;
    }
    if (first_input_minmax != input_array.GetMinMax()) {
      return false;
    }
  }

  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = first_input_minmax.min;
  output_minmax.max = first_input_minmax.max;
  return true;
}

}  // namespace

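// Dispatches on the operator type and applies the matching hardcoding rule
// above, reporting through *modified whether any MinMax was filled in.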
::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index,
                                         bool* modified) {
  *modified = false;
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  bool changed = false;
  switch (op->type) {
    case OperatorType::kConv:
      changed = HardcodeMinMaxForIm2colArray(model, op);
      break;

    case OperatorType::kL2Normalization:
      changed = HardcodeMinMaxForL2Normalization(model, op);
      break;

    case OperatorType::kRelu:
      // For any normalization other than batch norm, the quantization ranges
      // before and after the relu are expected to be known. Having a
      // quantization op before the relu would halve the number of bits of
      // precision available for the activation, so we deduce the range
      // before the relu from the range after it. This eliminates the need
      // for two fake quantization nodes without reducing the bits of
      // precision available for the activation.
      changed = HardcodeInputMinMaxFromOutput(model, op);
      break;

    case OperatorType::kConcatenation:
      changed = HardcodeMinMaxForConcatenation(model, op);
      break;

    case OperatorType::kSplit:
      changed = HardcodeMinMaxForSplit(model, op);
      break;

    case OperatorType::kAveragePool:
    case OperatorType::kMaxPool:
      changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
      break;

    case OperatorType::kResizeBilinear:
    case OperatorType::kResizeNearestNeighbor:
    case OperatorType::kSlice:
    case OperatorType::kStridedSlice:
    case OperatorType::kSqueeze:
    case OperatorType::kExpandDims:
    case OperatorType::kPad:
    case OperatorType::kGather:
    case OperatorType::kTranspose:
    case OperatorType::kMean:
    case OperatorType::kReduceMax:
    case OperatorType::kReduceMin:
      changed = HardcodeMinMaxFromFirstInput(model, op);
      break;
    case OperatorType::kPack:
      changed = HardcodeMinMaxForPack(model, op);
      break;
    case OperatorType::kSum:
      // reduce_sum is expected to change the output range, so a fake_quant
      // op on the output is normally necessary to minimize error. However,
      // in special circumstances, such as when computing an expected value
      // with reduce_sum, the input range and the output range match. The
      // code below acts as a fallback for that case; a fake_quant node
      // observed on the output takes precedence over this hardcoding logic.
      changed = HardcodeMinMaxFromFirstInput(model, op);
      if (changed) {
        LOG(WARNING) << "Using the input range for the output of a "
                     << "reduce_sum op. This could have an impact on your "
                     << "model accuracy.";
      }
      break;
    case OperatorType::kSelect:
      changed = HardcodeMinMaxForSelect(model, op);
      break;
    case OperatorType::kLogistic:
      // We hardcode quantization_params to: zero_point=0, scale=1/256.
      // This choice of minmax is the one that is equivalent to that.
      changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
      break;

    case OperatorType::kSoftmax:
      // We hardcode quantization_params to: zero_point=0, scale=1/256.
      // This choice of minmax is the one that is equivalent to that.
      changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
      break;

    case OperatorType::kTanh:
      // We hardcode quantization_params to: zero_point=127, scale=1/128.
      // This choice of minmax is the one that is equivalent to that.
      changed = HardcodeMinMaxForOutput(model, op, -127. / 128., 1.0);
      break;

    case OperatorType::kLstmCell:
      changed = HardcodeMinMaxForLstmCell(model, op);
      break;

    case OperatorType::kReshape:
      changed = HardcodeMinMaxForReshape(model, op);
      break;

    default:
      break;
  }
  if (changed) {
    AddMessageF("Hardcoded min-max through %s", LogName(*op));
  }
  *modified = changed;
  return ::tensorflow::Status::OK();
}

}  // namespace toco