1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
25
26 namespace toco {
27
28 namespace {
29
30 template <ArrayDataType A>
DequantizeBuffer(Array * array)31 void DequantizeBuffer(Array* array) {
32 const auto old_data = array->GetBuffer<A>().data;
33 array->buffer = nullptr;
34 array->data_type = ArrayDataType::kFloat;
35 auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
36 new_data.resize(old_data.size());
37 const auto& qparams = array->GetQuantizationParams();
38 for (int i = 0, end = old_data.size(); i < end; i++) {
39 new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
40 }
41 }
42
FindFirstOpWithInput(Model * model,const std::string & array_name)43 std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
44 Model* model, const std::string& array_name) {
45 for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
46 for (const auto& input : it->get()->inputs) {
47 if (input == array_name) {
48 return it;
49 }
50 }
51 }
52 return model->operators.end();
53 }
54
ClearArrayQuantizationParams(const std::string & array_name,Model * model)55 void ClearArrayQuantizationParams(const std::string& array_name, Model* model) {
56 auto* array = &model->GetArray(array_name);
57 CHECK(array->quantization_params);
58 for (auto& input_array : *model->flags.mutable_input_arrays()) {
59 if (input_array.name() == array_name) {
60 auto& qparams = *array->quantization_params;
61 const double new_std_value = 1. / qparams.scale;
62 const double new_mean_value = qparams.zero_point;
63 if (input_array.has_std_value()) {
64 CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
65 } else {
66 input_array.set_std_value(new_std_value);
67 }
68 if (input_array.has_mean_value()) {
69 CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
70 } else {
71 input_array.set_mean_value(new_mean_value);
72 }
73 }
74 }
75 array->quantization_params = nullptr;
76 }
77
DequantizeArray(const std::string & array_name,GraphTransformation * transformation,Model * model)78 bool DequantizeArray(const std::string& array_name,
79 GraphTransformation* transformation, Model* model) {
80 auto* array = &model->GetArray(array_name);
81 if (!array->quantization_params) {
82 return false;
83 }
84 transformation->AddMessageF("Dequantizing array: %s", array_name);
85
86 // Dequantize any buffer
87 if (array->buffer) {
88 if (array->data_type == ArrayDataType::kUint8) {
89 DequantizeBuffer<ArrayDataType::kUint8>(array);
90 } else if (array->data_type == ArrayDataType::kInt32) {
91 DequantizeBuffer<ArrayDataType::kInt32>(array);
92 } else {
93 LOG(FATAL) << "Unhandled data type";
94 }
95 CHECK(array->data_type == ArrayDataType::kFloat);
96 CHECK(array->buffer->type == ArrayDataType::kFloat);
97
98 // Clear quantization params, officially makes this a non-quantized array.
99 ClearArrayQuantizationParams(array_name, model);
100 return true;
101 } else {
102 array->data_type = ArrayDataType::kFloat;
103 }
104
105 // Clear quantization params, officially makes this a non-quantized array.
106 ClearArrayQuantizationParams(array_name, model);
107
108 if (array->buffer) {
109 return true;
110 }
111
112 auto* op_outputting_array = GetOpWithOutput(*model, array_name);
113 if (op_outputting_array) {
114 if (op_outputting_array->type == OperatorType::kReshape) {
115 return true;
116 }
117 }
118
119 // If there was no minmax info, we can return now. Indeed,
120 // the below only serves to create a FakeQuant node, but some arrays are
121 // quantized without MinMax (see the CHECK above) and that corresponds to
122 // places where a FakeQuant node is actually not wanted, because the
123 // quantization params are meant to be inferred in another way (e.g. bias
124 // vector for a Conv op, see their special-casing in quantize.cc).
125 if (!array->minmax) {
126 return true;
127 }
128
129 // Determine whether to insert a FakeQuant before or after
130 // this array.
131 bool must_insert_fakequant_before = false;
132 bool must_insert_fakequant_after = false;
133 if (IsInputArray(*model, array_name)) {
134 must_insert_fakequant_after = true;
135 }
136 for (const std::string& output_array : model->flags.output_arrays()) {
137 if (array_name == output_array) {
138 must_insert_fakequant_before = true;
139 }
140 }
141 for (const auto& rnn_state : model->flags.rnn_states()) {
142 if (array_name == rnn_state.state_array()) {
143 must_insert_fakequant_after = true;
144 }
145 if (array_name == rnn_state.back_edge_source_array()) {
146 must_insert_fakequant_before = true;
147 }
148 }
149 CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
150
151 // Create and insert the FakeQuant node
152 auto* fakequant_op = new FakeQuantOperator;
153 model->operators.emplace(FindFirstOpWithInput(model, array_name),
154 fakequant_op);
155 const std::string& new_array_name = AvailableArrayName(*model, array_name);
156 auto& new_array = model->GetOrCreateArray(new_array_name);
157 new_array.data_type = ArrayDataType::kFloat;
158 new_array.copy_shape(array->shape());
159 new_array.GetOrCreateMinMax() = array->GetMinMax();
160 fakequant_op->minmax = std::make_unique<MinMax>();
161 *fakequant_op->minmax = array->GetMinMax();
162 fakequant_op->narrow_range = array->narrow_range;
163 if (must_insert_fakequant_before) {
164 for (const auto& op : model->operators) {
165 for (std::string& output : op->outputs) {
166 if (output == array_name) {
167 output = new_array_name;
168 }
169 }
170 }
171 fakequant_op->inputs = {new_array_name};
172 fakequant_op->outputs = {array_name};
173 } else {
174 for (const auto& op : model->operators) {
175 for (std::string& input : op->inputs) {
176 if (input == array_name) {
177 input = new_array_name;
178 }
179 }
180 }
181 fakequant_op->inputs = {array_name};
182 fakequant_op->outputs = {new_array_name};
183 }
184 return true;
185 }
186
187 } // namespace
188
Run(Model * model,std::size_t op_index,bool * modified)189 ::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index,
190 bool* modified) {
191 *modified = false;
192 const auto op_it = model->operators.begin() + op_index;
193 auto* op = op_it->get();
194
195 if (op->type == OperatorType::kDequantize) {
196 auto& input_array = model->GetArray(op->inputs[0]);
197 if (input_array.data_type == ArrayDataType::kFloat) {
198 return ::tensorflow::OkStatus();
199 }
200 if (input_array.final_data_type != ArrayDataType::kFloat) {
201 return ::tensorflow::OkStatus();
202 }
203 input_array.data_type = ArrayDataType::kFloat;
204 input_array.quantization_params = nullptr;
205 auto& output_array = model->GetArray(op->outputs[0]);
206 output_array.data_type = ArrayDataType::kFloat;
207 output_array.quantization_params = nullptr;
208 *modified = RemoveTrivialPassthroughOp(this, model, op_index);
209 return ::tensorflow::OkStatus();
210 }
211
212 std::vector<std::string> arrays;
213 arrays.reserve(op->inputs.size());
214 for (const std::string& input : op->inputs) {
215 arrays.push_back(input);
216 }
217 for (const std::string& output : op->outputs) {
218 arrays.push_back(output);
219 }
220 bool changed = false;
221 for (const std::string& array : arrays) {
222 if (!model->IsOptionalArray(array)) {
223 changed |= DequantizeArray(array, this, model);
224 }
225 }
226
227 *modified = changed;
228 return ::tensorflow::OkStatus();
229 }
230
231 } // namespace toco
232