/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

namespace {

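// The im2col array (second output of Conv) is a temporary buffer of
// rearranged input patches, so it spans exactly the same value range as the
// input activations.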
bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) {
  if (op->outputs.size() != 2) {
    return false;
  }
  auto& im2col_array = model->GetArray(op->outputs[1]);
  if (im2col_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!im2col_array.minmax);
  auto& im2col_minmax = im2col_array.GetOrCreateMinMax();
  im2col_minmax.min = input_minmax.min;
  im2col_minmax.max = input_minmax.max;
  return true;
}

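// The output of L2 normalization always lies within [-1, 1]; this is
// tightened to [0, 1] (resp. [-1, 0]) when the input is known to be entirely
// non-negative (resp. non-positive).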
bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax.min >= 0. ? 0. : -1.;
  output_minmax.max = input_minmax.max <= 0. ? 0. : 1.;
  return true;
}

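// Propagates the output array's MinMax back onto the input array. Used for
// ops like Relu, where the input range can be deduced from the output range
// (see the comment on the kRelu case in Run() below).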
bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
  auto& input = model->GetArray(op->inputs[0]);
  if (input.minmax) {
    const auto* minmax = input.minmax.get();
    if (minmax) {
      return false;
    }
  }
  auto& output = model->GetArray(op->outputs[0]);
  if (output.minmax) {
    const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
    if (minmax) {
      input.GetOrCreateMinMax() = *minmax;
      return true;
    }
  }
  return false;
}

bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
  // Do not return early if the output already has min/max:
  // we may still need to adjust the inputs' min/max.
  bool has_minmax = false;
  double overall_min = std::numeric_limits<double>::infinity();
  double overall_max = -std::numeric_limits<double>::infinity();
  for (const auto& input : op->inputs) {
    if (model->GetArray(input).minmax) {
      has_minmax = true;
      const auto* minmax = model->GetArray(input).minmax.get();
      if (minmax) {
        overall_min = std::min(overall_min, minmax->min);
        overall_max = std::max(overall_max, minmax->max);
      }
    }
  }
  auto& output = model->GetArray(op->outputs[0]);
  if (output.minmax) {
    has_minmax = true;
    const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
    if (minmax) {
      overall_min = std::min(overall_min, minmax->min);
      overall_max = std::max(overall_max, minmax->max);
    }
  }
  if (!has_minmax) {
    return false;
  }
  MinMax overall_minmax;
  overall_minmax.min = overall_min;
  overall_minmax.max = overall_max;
  bool changed = false;
  if (model->flags.change_concat_input_ranges()) {
    for (const auto& input : op->inputs) {
      auto& array = model->GetArray(input);
      if (!array.minmax) {
        changed = true;
      } else if (!(overall_minmax == array.GetMinMax())) {
        changed = true;
        LOG(WARNING)
            << "Tweaking the MinMax of array " << input << ", which is "
            << "an input to " << LogName(*op) << ", because we want all inputs "
            << "and outputs of a Concatenation operator to have the same "
            << "MinMax so that it can be implemented as a pure byte-copy, no "
               "arithmetic.";
      }
      array.GetOrCreateMinMax() = overall_minmax;
    }
  }
  if (!output.minmax) {
    changed = true;
  } else if (!(overall_minmax == output.GetMinMax())) {
    if (model->flags.change_concat_input_ranges()) {
      changed = true;
      LOG(WARNING)
          << "Tweaking the MinMax of the output array of " << LogName(*op)
          << ", because we want all inputs "
          << "and outputs of a Concatenation operator to have the same MinMax "
          << "so that it can be implemented as a pure byte-copy, no "
          << "arithmetic.";
    } else {
      return false;
    }
  }
  output.GetOrCreateMinMax() = overall_minmax;

  return changed;
}

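// All outputs of a Split cover the same value range as the array being
// split, so the data input's MinMax is copied to every output.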
bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
  // Data is in second input.
  auto& input_array = model->GetArray(op->inputs[1]);
  if (!input_array.minmax) {
    return false;
  }
  bool changed = false;
  for (const auto& output : op->outputs) {
    auto& array = model->GetArray(output);
    if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) {
      changed = true;
      array.GetOrCreateMinMax() = *input_array.minmax;
    }
  }
  return changed;
}

// The output of average or max pooling is within the same range as its input.
bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = std::min(input_minmax.min, 0.);
  output_minmax.max = std::max(input_minmax.max, 0.);
  return true;
}

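// Copies the first input's MinMax to the output. Suitable for ops that only
// move or select data and therefore preserve the range of values.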
bool HardcodeMinMaxFromFirstInput(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  const auto& input_minmax = input_array.GetMinMax();
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax.min;
  output_minmax.max = input_minmax.max;
  return true;
}

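// A quantized Select needs both data inputs (inputs 1 and 2) and the output
// to share one quantization range. If only one data input has MinMax and the
// other is a constant, the known range is first copied onto the constant.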
bool HardcodeMinMaxForSelect(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }

  auto& input_array_1 = model->GetArray(op->inputs[1]);
  auto& input_array_2 = model->GetArray(op->inputs[2]);

  if (!input_array_1.minmax && !input_array_2.minmax) {
    return false;
  }

  // Propagate up if one input is quantized and the other is constant.
  if (!input_array_1.minmax &&
      IsConstantParameterArray(*model, op->inputs[1])) {
    auto& minmax_1 = input_array_1.GetOrCreateMinMax();
    const auto& minmax_2 = input_array_2.GetMinMax();
    minmax_1.min = minmax_2.min;
    minmax_1.max = minmax_2.max;
  }

  if (!input_array_2.minmax &&
      IsConstantParameterArray(*model, op->inputs[2])) {
    auto& minmax_2 = input_array_2.GetOrCreateMinMax();
    const auto& minmax_1 = input_array_1.GetMinMax();
    minmax_2.min = minmax_1.min;
    minmax_2.max = minmax_1.max;
  }

  if (!input_array_1.minmax || !input_array_2.minmax) {
    return false;
  }

  const auto& input_minmax_1 = input_array_1.GetMinMax();
  const auto& input_minmax_2 = input_array_2.GetMinMax();

  CHECK_EQ(input_minmax_1.min, input_minmax_2.min);
  CHECK_EQ(input_minmax_1.max, input_minmax_2.max);
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = input_minmax_1.min;
  output_minmax.max = input_minmax_1.max;
  return true;
}

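// Hardcodes the output array's MinMax to the given fixed [min, max] range,
// provided the input already has MinMax (i.e. the op is on a quantized path).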
bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
                             double max) {
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.minmax) {
    return false;
  }
  CHECK(!output_array.minmax);
  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = min;
  output_minmax.max = max;
  return true;
}

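// Returns whether two MinMax ranges agree to within a relative tolerance of
// 1e-6 of the narrower range's magnitude.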
bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
  const double magnitude =
      std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
  const double tolerated = 1e-6 * magnitude;
  return std::abs(minmax1.min - minmax2.min) <= tolerated &&
         std::abs(minmax1.max - minmax2.max) <= tolerated;
}

// Propagates MinMax from any of the listed arrays, to all others.
// If multiple of these arrays have MinMax, then these are required
// to agree with each other.
bool PropagateMinMaxAmongArrays(Model* model,
                                const std::vector<std::string>& array_names) {
  std::string reference_array_name;
  MinMax* reference_minmax = nullptr;
  for (const std::string& array_name : array_names) {
    if (model->GetArray(array_name).minmax) {
      reference_array_name = array_name;
      reference_minmax = model->GetArray(array_name).minmax.get();
      break;
    }
  }
  // No MinMax info is available to propagate.
  if (!reference_minmax) {
    return false;
  }
  bool changed = false;
  for (const std::string& array_name : array_names) {
    auto& array = model->GetArray(array_name);
    if (array.minmax) {
      CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
          << "Both the following arrays have minmax, and they disagree: "
          << reference_array_name << " (" << reference_minmax->min << ","
          << reference_minmax->max << ") and " << array_name << " ("
          << array.minmax->min << "," << array.minmax->max
          << "). Expected that either only one of them would have minmax, or "
             "at least that they would agree.";
    } else {
      array.GetOrCreateMinMax() = *reference_minmax;
      changed = true;
    }
  }
  return changed;
}

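// Reshape does not change any values, so its input and output must share the
// same MinMax; propagate whichever side is known to the other.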
bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
  Array& input = model->GetArray(op->inputs[0]);
  Array& output = model->GetArray(op->outputs[0]);

  // If input and output both have minmax, or both lack it, do nothing.
  if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) {
    return false;
  }

  // Otherwise propagate info amongst the input and output array.
  return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]});
}

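// Hardcodes the fixed quantization ranges required by the quantized LstmCell
// implementation: [-1, 127/128] for activation arrays, and a wider range for
// the internal fully-connected output (see the comment above ACTIV_TEMP
// below).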
bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
  CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);

  bool changed = false;
  changed |= PropagateMinMaxAmongArrays(
      model, {op->inputs[LstmCellOperator::PREV_STATE_INPUT],
              op->outputs[LstmCellOperator::STATE_OUTPUT]});

  auto& input_activations =
      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
  if (!input_activations.minmax) {
    auto& minmax = input_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& prev_output_activations =
      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
  if (!prev_output_activations.minmax) {
    auto& minmax = prev_output_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& output_concat_temp =
      model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]);
  if (!output_concat_temp.minmax) {
    auto& minmax = output_concat_temp.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  auto& output_activations =
      model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]);
  if (!output_activations.minmax) {
    auto& minmax = output_activations.GetOrCreateMinMax();
    minmax.min = -1;
    minmax.max = 127. / 128.;
    changed = true;
  }

  // (This comment should morph into proper documentation for
  // quantization of LSTM models. It isn't just a local implementation detail:
  // the training code for LSTM models needs to be adjusted accordingly.)
  //
  // Finally, output_activations_temp holds the output of the fully-connected
  // node inside the LSTM cell. For it, we hardcode a minmax of [-8, 8].
  // The rationale for that is given in a lengthy comment on the LstmCell
  // quantized runtime implementation in reference_ops.h.
  auto& output_activations_temp =
      model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]);
  if (!output_activations_temp.minmax) {
    auto& minmax = output_activations_temp.GetOrCreateMinMax();
    minmax.min = -8;
    minmax.max = 8 * 32767. / 32768.;
    changed = true;
  }

  return changed;
}

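// Pack merely stacks its inputs, so the output range is the common input
// range. This only applies when all inputs share the same MinMax; otherwise
// nothing can be hardcoded here.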
bool HardcodeMinMaxForPack(Model* model, Operator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.minmax) {
    return false;
  }

  // If all tensors being packed have the same min/max range, hardcode min/max
  // for the output.
  const auto& first_input_array = model->GetArray(op->inputs[0]);
  if (!first_input_array.minmax) {
    return false;
  }
  const auto& first_input_minmax = first_input_array.GetMinMax();

  for (size_t i = 1; i < op->inputs.size(); i++) {
    const auto& input_array = model->GetArray(op->inputs[i]);
    if (!input_array.minmax) {
      return false;
    }
    if (first_input_minmax != input_array.GetMinMax()) {
      return false;
    }
  }

  auto& output_minmax = output_array.GetOrCreateMinMax();
  output_minmax.min = first_input_minmax.min;
  output_minmax.max = first_input_minmax.max;
  return true;
}

}  // namespace

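// Dispatches to the per-operator hardcoding rule for the operator at
// op_index, and reports through *modified whether any MinMax was set or
// adjusted.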
::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index,
                                         bool* modified) {
  *modified = false;
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  bool changed = false;
  switch (op->type) {
    case OperatorType::kConv:
      changed = HardcodeMinMaxForIm2colArray(model, op);
      break;

    case OperatorType::kL2Normalization:
      changed = HardcodeMinMaxForL2Normalization(model, op);
      break;

    case OperatorType::kRelu:
      // For any normalization other than batch norm, the quantization ranges
      // before and after the relu are expected to be known. Having a
      // quantization op before the relu would cut the number of bits of
      // precision for the activation in half. So we deduce the range before
      // the relu from the range after it. This eliminates the need for two
      // fake-quantization nodes and does not reduce the bits of precision
      // available for the activation.
      changed = HardcodeInputMinMaxFromOutput(model, op);
      break;

    case OperatorType::kConcatenation:
      changed = HardcodeMinMaxForConcatenation(model, op);
      break;

    case OperatorType::kSplit:
      changed = HardcodeMinMaxForSplit(model, op);
      break;

    case OperatorType::kAveragePool:
    case OperatorType::kMaxPool:
      changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
      break;

    case OperatorType::kResizeBilinear:
    case OperatorType::kResizeNearestNeighbor:
    case OperatorType::kSlice:
    case OperatorType::kStridedSlice:
    case OperatorType::kSqueeze:
    case OperatorType::kExpandDims:
    case OperatorType::kPad:
    case OperatorType::kGather:
    case OperatorType::kTranspose:
    case OperatorType::kMean:
    case OperatorType::kReduceMax:
    case OperatorType::kReduceMin:
      changed = HardcodeMinMaxFromFirstInput(model, op);
      break;
    case OperatorType::kPack:
      changed = HardcodeMinMaxForPack(model, op);
      break;
    case OperatorType::kSum:
      // reduce_sum is expected to change the output range, so a fake_quant
      // op on the output is normally necessary to minimize error. However,
      // in special circumstances, such as when computing an expected value
      // with reduce_sum, the input range and the output range match. Hence
      // the code below acts as a fallback: if a fake_quant node is observed
      // on the output, it takes precedence over this hardcoding logic.
      changed = HardcodeMinMaxFromFirstInput(model, op);
      if (changed) {
        LOG(WARNING) << "Using the input range for output in reduce_sum op. "
                     << "This could have an impact on your model accuracy.";
      }
      break;
    case OperatorType::kSelect:
      changed = HardcodeMinMaxForSelect(model, op);
      break;
    case OperatorType::kLogistic:
      // We hardcode quantization_params to: zero_point=0, scale=1/256.
      // This choice of minmax is equivalent to those quantization params.
      changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
      break;

    case OperatorType::kSoftmax:
      // We hardcode quantization_params to: zero_point=0, scale=1/256.
      // This choice of minmax is equivalent to those quantization params.
      changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
      break;

    case OperatorType::kTanh:
      // We hardcode quantization_params to: zero_point=127, scale=1/128.
      // This choice of minmax is equivalent to those quantization params.
      changed = HardcodeMinMaxForOutput(model, op, -127. / 128., 1.0);
      break;

    case OperatorType::kLstmCell:
      changed = HardcodeMinMaxForLstmCell(model, op);
      break;

    case OperatorType::kReshape:
      changed = HardcodeMinMaxForReshape(model, op);
      break;

    default:
      break;
  }
  if (changed) {
    AddMessageF("Hardcoded min-max through %s", LogName(*op));
  }
  *modified = changed;
  return ::tensorflow::Status::OK();
}

}  // namespace toco