/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"

#include <limits>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "flatbuffers/flexbuffers.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"

namespace tflite {
namespace optimize {

namespace {

// Structure to hold input tensor, op and output tensor.
// op must be either quantize or dequantize.
struct TensorOpTensor {
  size_t subgraph_index;  // index of the subgraph.
  int32_t input_index;    // index of the input tensor.
  int32_t op_index;       // index of the op.
  int32_t output_index;   // index of the output tensor.
  int32_t model_index;    // index of the added tensor in the model.
};

// Finds float tensors that are model inputs and are consumed by a quantize op.
// The returned TensorOpTensors are in reverse operator order, so erasing by
// index stays valid.
std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
                                            ModelT* model,
                                            ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all input tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    std::unordered_map<TensorT*, int> input_tensors;
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->inputs[input_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        input_tensors.insert({tensor, input_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* input_tensor = subgraph->tensors[op->inputs[0]].get();
      if (input_tensors.find(input_tensor) == input_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_QUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_input_index = input_tensors[input_tensor];
      TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
      if (quant_output->type != TensorType_INT8 &&
          quant_output->type != TensorType_INT16) {
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
        return {};
      }

      // The input type must be the same as the model quantization type.
      if (input_type != quant_output->type) {
        // An exception: allow a UINT8 input type for an INT8 quantized model.
        if (!(input_type == TensorType_UINT8 &&
              quant_output->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s input type is incompatible with %s quantized models. "
              "To resolve this error, change the input_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(input_type),
              EnumNameTensorType(quant_output->type));
        }
      }
      if (quant_output->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_input_index});
    }
  }
  return result;
}

// Finds float tensors that are model outputs and are produced by a dequantize
// op. The returned TensorOpTensors are in reverse operator order, so erasing
// by index stays valid.
std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
                                             ModelT* model,
                                             ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all output tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    std::unordered_map<TensorT*, int> output_tensors;
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->outputs[output_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        output_tensors.insert({tensor, output_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get();
      if (output_tensors.find(output_tensor) == output_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_DEQUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_output_index = output_tensors[output_tensor];
      TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
      if (dequant_input->type != TensorType_INT8 &&
          dequant_input->type != TensorType_INT16) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
        return {};
      }
      if (output_type != dequant_input->type) {
        // An exception: allow a UINT8 output type for an INT8 quantized model.
        if (!(output_type == TensorType_UINT8 &&
              dequant_input->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s output type is incompatible with %s quantized models. "
              "To resolve this error, change the output_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(output_type),
              EnumNameTensorType(dequant_input->type));
        }
      }
      if (dequant_input->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_output_index});
    }
  }
  return result;
}

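// Rewrites each float model input found by GetInputTensors to be UINT8,
// copying the scale from the quantize op's output and shifting the zero point
// by 128 (the int8 -> uint8 offset). The quantize op is kept and now
// requantizes uint8 -> int8.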
TfLiteStatus SetInputTypeToUINT8(ModelT* model,
                                 const std::vector<TensorOpTensor>& inputs) {
  // If the input type is uint8, change float to uint8.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.output_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.input_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = absl::make_unique<QuantizationParametersT>();
    }
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);
  }
  return kTfLiteOk;
}

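// Rewrites each float model output found by GetOutputTensors to be UINT8 with
// the dequantize input's scale and a zero point shifted by 128, and converts
// the trailing dequantize op into a quantize op (int8 -> uint8 requantize).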
TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
                                  const std::vector<TensorOpTensor>& outputs) {
  // Find Quant op code index.
  size_t quant_op_index = 0;
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    if (GetBuiltinCode(model->operator_codes[i].get()) ==
        BuiltinOperator_QUANTIZE) {
      quant_op_index = i;
    }
  }
  // If the output type is uint8, change float to uint8.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.input_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.output_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = absl::make_unique<QuantizationParametersT>();
    }
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);

    // Change the op from dequant (int8 to float) to quant (int8 to uint8).
    OperatorT* op = subgraph->operators[tot.op_index].get();
    op->opcode_index = quant_op_index;
  }
  return kTfLiteOk;
}

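// Removes the leading quantize op for each entry in `inputs` and re-points the
// subgraph input to the op's quantized output tensor. The float tensor is only
// erased when its index lies beyond `original_number_tensors`; entries must be
// in reverse index order so the erase positions stay valid.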
TfLiteStatus RemoveInputTensor(ModelT* model,
                               const std::vector<TensorOpTensor>& inputs,
                               int32 original_number_tensors) {
  // Consistency check to make sure that erasing starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : inputs) {
    TFLITE_DCHECK(tot.input_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.input_index;
    last_op_index = tot.op_index;
  }
  // Removes the input tensor and the related operator.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.input_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->inputs[tot.model_index] = tot.output_index;
  }
  return kTfLiteOk;
}

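// Removes the trailing dequantize op for each entry in `outputs` and re-points
// the subgraph output to the op's quantized input tensor. The float tensor is
// only erased when its index lies beyond `original_number_tensors`; entries
// must be in reverse index order so the erase positions stay valid.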
TfLiteStatus RemoveOutputTensor(ModelT* model,
                                const std::vector<TensorOpTensor>& outputs,
                                int32 original_number_tensors) {
  // Consistency check to make sure that erasing starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : outputs) {
    TFLITE_DCHECK(tot.output_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.output_index;
    last_op_index = tot.op_index;
  }
  // Removes the output tensor and the related operator.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.output_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->outputs[tot.model_index] = tot.input_index;
  }
  return kTfLiteOk;
}

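// Returns the tensor count of the primary subgraph before the interface
// quantize/dequantize tensors were appended, i.e. the current count minus the
// tensors matched by GetInputTensors and GetOutputTensors.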
int GetOriginalNumberOfTensors(const TensorType& input_type,
                               const TensorType& output_type, ModelT* model,
                               ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, error_reporter);
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, error_reporter);
  return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
}

}  // namespace

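// Rewrites the model so that its input/output tensors use `input_type` /
// `output_type` directly: UINT8 interfaces reuse the existing interface ops
// with zero points shifted by 128, while INT8/INT16 interfaces drop the
// leading quantize and trailing dequantize ops. The modified model is
// serialized into `builder`.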
TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                  ModelT* model, const TensorType& input_type,
                                  const TensorType& output_type) {
  tflite::StderrReporter error_reporter;
  const int original_number_tensors = GetOriginalNumberOfTensors(
      input_type, output_type, model, &error_reporter);
  // Finds float tensors that are model outputs and are produced by an
  // int8/int16-to-float dequantize op. Handle the outputs first since the
  // input-side tensors were added to the graph first.
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, &error_reporter);
  switch (output_type) {
    case TensorType_UINT8:
      SetOutputTypeToUINT8(model, outputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveOutputTensor(model, outputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Finds float tensors that are model inputs and are consumed by a
  // float-to-int8/int16 quantize op.
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, &error_reporter);
  switch (input_type) {
    case TensorType_UINT8:
      SetInputTypeToUINT8(model, inputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveInputTensor(model, inputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Write to builder.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model);
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

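// Convenience overload that reads a quantized .tflite file, modifies its
// interface, and writes the result back to disk. For example (the paths here
// are hypothetical):
//
//   ModifyModelInterface("/tmp/model_float_io.tflite",
//                        "/tmp/model_int8_io.tflite",
//                        TensorType_INT8, TensorType_INT8);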
TfLiteStatus ModifyModelInterface(const string& input_file,
                                  const string& output_file,
                                  const TensorType& input_type,
                                  const TensorType& output_type) {
  // Consistency check.
  if (input_type != tflite::TensorType_INT8 &&
      input_type != tflite::TensorType_UINT8 &&
      input_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }
  if (output_type != tflite::TensorType_INT8 &&
      output_type != tflite::TensorType_UINT8 &&
      output_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }

  // Create model.
  auto tflite_model = utils::CreateMutableModelFromFile(input_file);

  auto model_builder = utils::FinishModel(tflite_model.get());

  auto fixed_point_model_builder =
      absl::make_unique<flatbuffers::FlatBufferBuilder>();
  flatbuffers::FlatBufferBuilder builder;

  auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
                                     output_type);
  TFLITE_DCHECK_EQ(status, kTfLiteOk);

  utils::WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());

  return kTfLiteOk;
}

namespace {
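// For every float32 subgraph input listed in `quant_params` (keyed by tensor
// name, value = {scale, zero_point}), appends a new UINT8 tensor and a leading
// dequantize op that converts the uint8 input to the original float tensor.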
void AddUint8Dequant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add dequant to input tensors.
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      const int32_t tensor_idx = subgraph->inputs[input_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add the uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> leading_op_input;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &leading_op_input);
        const int32_t leading_op_input_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(leading_op_input));

        // Create the leading op, which is a dequantize op.
        std::unique_ptr<OperatorT> leading_op;
        utils::MakeDequantizeOperator(model, &leading_op, leading_op_input_idx,
                                      tensor_idx);

        // Insert the new op at the start of the model.
        subgraph->operators.insert(subgraph->operators.begin(),
                                   std::move(leading_op));
      }
    }
  }
}

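// For every float32 subgraph output listed in `quant_params`, appends a new
// UINT8 tensor and a trailing quantize op that converts the float output to
// uint8 using the provided {scale, zero_point}.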
void AddUint8Quant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add quant to output tensors.
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      const int32_t tensor_idx = subgraph->outputs[output_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add the uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> tailing_op_output;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &tailing_op_output);
        const int32_t tailing_op_output_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(tailing_op_output));

        // Create the tailing op, which is a quantize op.
        std::unique_ptr<OperatorT> tailing_op;
        utils::MakeQuantizeOperator(model, &tailing_op, tensor_idx,
                                    tailing_op_output_idx);

        // Insert the new op at the end of the model.
        subgraph->operators.push_back(std::move(tailing_op));
      }
    }
  }
}
}  // namespace

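// Wraps a float-interface model with uint8 I/O: a dequantize op is prepended
// for each named input and a quantize op is appended for each named output,
// using the caller-provided {scale, zero_point} pairs. The result is
// serialized into `builder`.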
TfLiteStatus Uint8QuantizeModelInputsOutputs(
    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        input_quant_params,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        output_quant_params) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());
  // Add Dequant for inputs.
  AddUint8Dequant(input_quant_params, model.get());

  // Add Quant for outputs.
  AddUint8Quant(output_quant_params, model.get());

  // Output model.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

}  // namespace optimize
}  // namespace tflite