• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cmath>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
25 
26 namespace toco {
27 
28 namespace {
29 
30 template <ArrayDataType A>
DequantizeBuffer(Array * array)31 void DequantizeBuffer(Array* array) {
32   const auto old_data = array->GetBuffer<A>().data;
33   array->buffer = nullptr;
34   array->data_type = ArrayDataType::kFloat;
35   auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
36   new_data.resize(old_data.size());
37   const auto& qparams = array->GetQuantizationParams();
38   for (int i = 0, end = old_data.size(); i < end; i++) {
39     new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
40   }
41 }
42 
FindFirstOpWithInput(Model * model,const std::string & array_name)43 std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
44     Model* model, const std::string& array_name) {
45   for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
46     for (const auto& input : it->get()->inputs) {
47       if (input == array_name) {
48         return it;
49       }
50     }
51   }
52   return model->operators.end();
53 }
54 
ClearArrayQuantizationParams(const std::string & array_name,Model * model)55 void ClearArrayQuantizationParams(const std::string& array_name, Model* model) {
56   auto* array = &model->GetArray(array_name);
57   CHECK(array->quantization_params);
58   for (auto& input_array : *model->flags.mutable_input_arrays()) {
59     if (input_array.name() == array_name) {
60       auto& qparams = *array->quantization_params;
61       const double new_std_value = 1. / qparams.scale;
62       const double new_mean_value = qparams.zero_point;
63       if (input_array.has_std_value()) {
64         CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
65       } else {
66         input_array.set_std_value(new_std_value);
67       }
68       if (input_array.has_mean_value()) {
69         CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
70       } else {
71         input_array.set_mean_value(new_mean_value);
72       }
73     }
74   }
75   array->quantization_params = nullptr;
76 }
77 
DequantizeArray(const std::string & array_name,GraphTransformation * transformation,Model * model)78 bool DequantizeArray(const std::string& array_name,
79                      GraphTransformation* transformation, Model* model) {
80   auto* array = &model->GetArray(array_name);
81   if (!array->quantization_params) {
82     return false;
83   }
84   transformation->AddMessageF("Dequantizing array: %s", array_name);
85 
86   // Dequantize any buffer
87   if (array->buffer) {
88     if (array->data_type == ArrayDataType::kUint8) {
89       DequantizeBuffer<ArrayDataType::kUint8>(array);
90     } else if (array->data_type == ArrayDataType::kInt32) {
91       DequantizeBuffer<ArrayDataType::kInt32>(array);
92     } else {
93       LOG(FATAL) << "Unhandled data type";
94     }
95     CHECK(array->data_type == ArrayDataType::kFloat);
96     CHECK(array->buffer->type == ArrayDataType::kFloat);
97 
98     // Clear quantization params, officially makes this a non-quantized array.
99     ClearArrayQuantizationParams(array_name, model);
100     return true;
101   } else {
102     array->data_type = ArrayDataType::kFloat;
103   }
104 
105   // Clear quantization params, officially makes this a non-quantized array.
106   ClearArrayQuantizationParams(array_name, model);
107 
108   if (array->buffer) {
109     return true;
110   }
111 
112   auto* op_outputting_array = GetOpWithOutput(*model, array_name);
113   if (op_outputting_array) {
114     if (op_outputting_array->type == OperatorType::kReshape) {
115       return true;
116     }
117   }
118 
119   // If there was no minmax info, we can return now. Indeed,
120   // the below only serves to create a FakeQuant node, but some arrays are
121   // quantized without MinMax (see the CHECK above) and that corresponds to
122   // places where a FakeQuant node is actually not wanted, because the
123   // quantization params are meant to be inferred in another way (e.g. bias
124   // vector for a Conv op, see their special-casing in quantize.cc).
125   if (!array->minmax) {
126     return true;
127   }
128 
129   // Determine whether to insert a FakeQuant before or after
130   // this array.
131   bool must_insert_fakequant_before = false;
132   bool must_insert_fakequant_after = false;
133   if (IsInputArray(*model, array_name)) {
134     must_insert_fakequant_after = true;
135   }
136   for (const std::string& output_array : model->flags.output_arrays()) {
137     if (array_name == output_array) {
138       must_insert_fakequant_before = true;
139     }
140   }
141   for (const auto& rnn_state : model->flags.rnn_states()) {
142     if (array_name == rnn_state.state_array()) {
143       must_insert_fakequant_after = true;
144     }
145     if (array_name == rnn_state.back_edge_source_array()) {
146       must_insert_fakequant_before = true;
147     }
148   }
149   CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
150 
151   // Create and insert the FakeQuant node
152   auto* fakequant_op = new FakeQuantOperator;
153   model->operators.emplace(FindFirstOpWithInput(model, array_name),
154                            fakequant_op);
155   const std::string& new_array_name = AvailableArrayName(*model, array_name);
156   auto& new_array = model->GetOrCreateArray(new_array_name);
157   new_array.data_type = ArrayDataType::kFloat;
158   new_array.copy_shape(array->shape());
159   new_array.GetOrCreateMinMax() = array->GetMinMax();
160   fakequant_op->minmax = std::make_unique<MinMax>();
161   *fakequant_op->minmax = array->GetMinMax();
162   fakequant_op->narrow_range = array->narrow_range;
163   if (must_insert_fakequant_before) {
164     for (const auto& op : model->operators) {
165       for (std::string& output : op->outputs) {
166         if (output == array_name) {
167           output = new_array_name;
168         }
169       }
170     }
171     fakequant_op->inputs = {new_array_name};
172     fakequant_op->outputs = {array_name};
173   } else {
174     for (const auto& op : model->operators) {
175       for (std::string& input : op->inputs) {
176         if (input == array_name) {
177           input = new_array_name;
178         }
179       }
180     }
181     fakequant_op->inputs = {array_name};
182     fakequant_op->outputs = {new_array_name};
183   }
184   return true;
185 }
186 
187 }  // namespace
188 
Run(Model * model,std::size_t op_index,bool * modified)189 ::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index,
190                                      bool* modified) {
191   *modified = false;
192   const auto op_it = model->operators.begin() + op_index;
193   auto* op = op_it->get();
194 
195   if (op->type == OperatorType::kDequantize) {
196     auto& input_array = model->GetArray(op->inputs[0]);
197     if (input_array.data_type == ArrayDataType::kFloat) {
198       return ::tensorflow::OkStatus();
199     }
200     if (input_array.final_data_type != ArrayDataType::kFloat) {
201       return ::tensorflow::OkStatus();
202     }
203     input_array.data_type = ArrayDataType::kFloat;
204     input_array.quantization_params = nullptr;
205     auto& output_array = model->GetArray(op->outputs[0]);
206     output_array.data_type = ArrayDataType::kFloat;
207     output_array.quantization_params = nullptr;
208     *modified = RemoveTrivialPassthroughOp(this, model, op_index);
209     return ::tensorflow::OkStatus();
210   }
211 
212   std::vector<std::string> arrays;
213   arrays.reserve(op->inputs.size());
214   for (const std::string& input : op->inputs) {
215     arrays.push_back(input);
216   }
217   for (const std::string& output : op->outputs) {
218     arrays.push_back(output);
219   }
220   bool changed = false;
221   for (const std::string& array : arrays) {
222     if (!model->IsOptionalArray(array)) {
223       changed |= DequantizeArray(array, this, model);
224     }
225   }
226 
227   *modified = changed;
228   return ::tensorflow::OkStatus();
229 }
230 
231 }  // namespace toco
232