/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"

#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>

#include "flatbuffers/flexbuffers.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"

namespace tflite {
namespace optimize {

namespace {

// Structure to hold input tensor, op and output tensor.
// op must be either quantize or dequantize.
struct TensorOpTensor {
  size_t subgraph_index;  // index of the subgraph.
  int32_t input_index;    // index of the input tensor.
  int32_t op_index;       // index of the op.
  int32_t output_index;   // index of the output tensor.
  int32_t model_index;    // index of the added tensor in the model.
};
// Finds float tensors that are model inputs and are consumed by a quantize Op.
// The returned TensorOpTensors are in reverse (descending op index) order.
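// In a fully quantized model, the leading pattern for each float input is:
//   float32 model input -> QUANTIZE -> int8/int16 tensor -> (rest of graph)
// Each match is recorded as {subgraph, float input, quantize op, quantized
// output, position of the input in subgraph->inputs}.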
std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
                                            ModelT* model,
                                            ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all input tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    absl::flat_hash_map<TensorT*, int> input_tensors;
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->inputs[input_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        input_tensors.insert({tensor, input_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* input_tensor = subgraph->tensors[op->inputs[0]].get();
      if (input_tensors.find(input_tensor) == input_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_QUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_input_index = input_tensors[input_tensor];
      TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
      if (quant_output->type != TensorType_INT8 &&
          quant_output->type != TensorType_INT16) {
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
      }

      // The input type must be the same as the model quantization type.
      if (input_type != quant_output->type) {
        // An exception, allow for UINT8 input type for INT8 quantized model.
        if (!(input_type == TensorType_UINT8 &&
              quant_output->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s input type is incompatible with %s quantized models. "
              "To resolve this error, change the input_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(input_type),
              EnumNameTensorType(quant_output->type));
        }
      }
      if (quant_output->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_input_index});
    }
  }
  return result;
}

// Finds float tensors that are model outputs and are produced by a dequantize
// Op. The returned TensorOpTensors are in reverse (descending op index) order.
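// In a fully quantized model, the trailing pattern for each float output is:
//   (rest of graph) -> int8/int16 tensor -> DEQUANTIZE -> float32 model output
// Each match is recorded as {subgraph, quantized input, dequantize op, float
// output, position of the output in subgraph->outputs}.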
std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
                                             ModelT* model,
                                             ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all output tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    absl::flat_hash_map<TensorT*, int> output_tensors;
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->outputs[output_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        output_tensors.insert({tensor, output_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get();
      if (output_tensors.find(output_tensor) == output_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_DEQUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_output_index = output_tensors[output_tensor];
      TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
      if (dequant_input->type != TensorType_INT8 &&
          dequant_input->type != TensorType_INT16) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
        return {};
      }
      if (output_type != dequant_input->type) {
        // An exception, allow for UINT8 output type for INT8 quantized model.
        if (!(output_type == TensorType_UINT8 &&
              dequant_input->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s output type is incompatible with %s quantized models. "
              "To resolve this error, change the output_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(output_type),
              EnumNameTensorType(dequant_input->type));
        }
      }
      if (dequant_input->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_output_index});
    }
  }
  return result;
}

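// Rewrites each recorded leading float input so the model consumes uint8
// directly: the float tensor's type becomes UINT8 and it takes on the
// downstream quantize op's output scale with the zero point shifted by +128.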
TfLiteStatus SetInputTypeToUINT8(ModelT* model,
                                 const std::vector<TensorOpTensor>& inputs) {
  // If the input type is uint8, change float to uint8.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.output_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.input_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = std::make_unique<QuantizationParametersT>();
    }
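    // An int8 value v and the uint8 value v + 128 represent the same real
    // number under a shared scale, so the uint8 input reuses the quantize op's
    // output scale and offsets its zero point by 128.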
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);
  }
  return kTfLiteOk;
}

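// Rewrites each recorded trailing dequantize op so the model produces uint8
// directly: the float output tensor becomes UINT8 (same scale, zero point
// shifted by +128) and the op's opcode is switched from DEQUANTIZE to
// QUANTIZE, which then requantizes int8 to uint8.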
TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
                                  const std::vector<TensorOpTensor>& outputs) {
  // Find Quant op code index.
  size_t quant_op_index = 0;
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    if (GetBuiltinCode(model->operator_codes[i].get()) ==
        BuiltinOperator_QUANTIZE) {
      quant_op_index = i;
    }
  }
  // If the output type is uint8, change float to uint8.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.input_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.output_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = std::make_unique<QuantizationParametersT>();
    }
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);

    // Change op from dequant (int8 to float) to quant (int8 to uint8).
    OperatorT* op = subgraph->operators[tot.op_index].get();
    op->opcode_index = quant_op_index;
  }
  return kTfLiteOk;
}

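// Drops each recorded leading quantize op and rewires the corresponding model
// input to the op's quantized output. The float tensor itself is erased only
// when its index lies at or beyond the original tensor count, i.e. when it is
// not one of the original model's tensors.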
TfLiteStatus RemoveInputTensor(ModelT* model,
                               const std::vector<TensorOpTensor>& inputs,
                               int32 original_number_tensors) {
  // Consistency check to make sure that erasure starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : inputs) {
    TFLITE_DCHECK(tot.input_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.input_index;
    last_op_index = tot.op_index;
  }
  // Removes the input tensor and the related operator.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.input_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->inputs[tot.model_index] = tot.output_index;
  }
  return kTfLiteOk;
}

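// Drops each recorded trailing dequantize op and rewires the corresponding
// model output to the op's quantized input, mirroring RemoveInputTensor.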
TfLiteStatus RemoveOutputTensor(ModelT* model,
                                const std::vector<TensorOpTensor>& outputs,
                                int32 original_number_tensors) {
  // Consistency check to make sure that erasure starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : outputs) {
    TFLITE_DCHECK(tot.output_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.output_index;
    last_op_index = tot.op_index;
  }
  // Removes the output tensor and the related operator.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.output_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->outputs[tot.model_index] = tot.input_index;
  }
  return kTfLiteOk;
}

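// Estimates the tensor count of the primary subgraph without the extra
// interface tensors: one tensor per matched leading quantize and trailing
// dequantize op is subtracted. RemoveInputTensor/RemoveOutputTensor use this
// as the threshold below which tensors are never erased.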
int GetOriginalNumberOfTensors(const TensorType& input_type,
                               const TensorType& output_type, ModelT* model,
                               ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, error_reporter);
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, error_reporter);
  return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
}

}  // namespace

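// Changes the I/O interface of an already-quantized model: for int8/int16 the
// leading quantize / trailing dequantize ops are removed, and for uint8 they
// are kept but rewired to consume / produce uint8. The result is serialized
// into `builder`.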
TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                  ModelT* model, const TensorType& input_type,
                                  const TensorType& output_type) {
  tflite::StderrReporter error_reporter;
  const int original_number_tensors = GetOriginalNumberOfTensors(
      input_type, output_type, model, &error_reporter);
  // Finds float tensors that are model outputs and are produced by an
  // int8/int16 to float dequantize Op. Do the outputs first since the tensors
  // are added into the inputs first.
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, &error_reporter);
  switch (output_type) {
    case TensorType_UINT8:
      SetOutputTypeToUINT8(model, outputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveOutputTensor(model, outputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Finds float tensors that are model inputs and are consumed by a float to
  // int8/int16 quantize Op.
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, &error_reporter);
  switch (input_type) {
    case TensorType_UINT8:
      SetInputTypeToUINT8(model, inputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveInputTensor(model, inputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Write to builder.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model);
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

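// Convenience overload that reads a .tflite file, rewrites its interface, and
// writes the result to a new file. Example (hypothetical paths):
//   ModifyModelInterface("/tmp/quantized.tflite",
//                        "/tmp/quantized_int8_io.tflite",
//                        TensorType_INT8, TensorType_INT8);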
TfLiteStatus ModifyModelInterface(const string& input_file,
                                  const string& output_file,
                                  const TensorType& input_type,
                                  const TensorType& output_type) {
  // Consistency Check
  if (input_type != tflite::TensorType_INT8 &&
      input_type != tflite::TensorType_UINT8 &&
      input_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }
  if (output_type != tflite::TensorType_INT8 &&
      output_type != tflite::TensorType_UINT8 &&
      output_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }

  // Create model.
  auto tflite_model = utils::CreateMutableModelFromFile(input_file);

  auto model_builder = utils::FinishModel(tflite_model.get());

  auto fixed_point_model_builder =
      std::make_unique<flatbuffers::FlatBufferBuilder>();
  flatbuffers::FlatBufferBuilder builder;

  auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
                                     output_type);
  TFLITE_DCHECK_EQ(status, kTfLiteOk);

  utils::WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());

  return kTfLiteOk;
}

namespace {
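// For every float32 model input whose name appears in `quant_params`, adds a
// uint8 tensor with the provided (scale, zero_point) and a leading DEQUANTIZE
// op that converts that uint8 tensor into the original float input.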
void AddUint8Dequant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add dequant to input tensors.
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      const int32_t tensor_idx = subgraph->inputs[input_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> leading_op_input;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &leading_op_input);
        const int32_t leading_op_input_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(leading_op_input));

        // Create the leading op, which is a dequantize Op.
        std::unique_ptr<OperatorT> leading_op;
        utils::MakeDequantizeOperator(model, &leading_op, leading_op_input_idx,
                                      tensor_idx);

        // Insert the new op at the start of the model.
        subgraph->operators.insert(subgraph->operators.begin(),
                                   std::move(leading_op));
      }
    }
  }
}

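// For every float32 model output whose name appears in `quant_params`, adds a
// uint8 tensor with the provided (scale, zero_point) and a trailing QUANTIZE
// op that converts the original float output into that uint8 tensor.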
void AddUint8Quant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add quant to output tensors.
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      const int32_t tensor_idx = subgraph->outputs[output_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> tailing_op_output;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &tailing_op_output);
        const int32_t tailing_op_output_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(tailing_op_output));

        // Create the tailing op, which is a quantize Op.
        std::unique_ptr<OperatorT> tailing_op;
        utils::MakeQuantizeOperator(model, &tailing_op, tensor_idx,
                                    tailing_op_output_idx);

        // Insert the new op at the end of the model.
        subgraph->operators.push_back(std::move(tailing_op));
      }
    }
  }
}
}  // namespace

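// Wraps a model whose interface is float with uint8 <-> float conversion ops,
// using externally supplied quantization parameters keyed by tensor name.
// Example (hypothetical names and values):
//   std::unordered_map<string, std::pair<float, int32_t>> in = {
//       {"input_tensor", {1.0f / 255.0f, 0}}};
//   std::unordered_map<string, std::pair<float, int32_t>> out = {
//       {"output_tensor", {1.0f / 255.0f, 0}}};
//   flatbuffers::FlatBufferBuilder fbb;
//   Uint8QuantizeModelInputsOutputs(&fbb, float_io_model, in, out);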
TfLiteStatus Uint8QuantizeModelInputsOutputs(
    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        input_quant_params,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        output_quant_params) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());
  // Add Dequant for inputs.
  AddUint8Dequant(input_quant_params, model.get());

  // Add Quant for outputs.
  AddUint8Quant(output_quant_params, model.get());

  // Output model.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

}  // namespace optimize
}  // namespace tflite