/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"

#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"

namespace tflite {
namespace optimize {

namespace {

// Structure to hold input tensor, op and output tensor.
// op must be either quantize or dequantize.
struct TensorOpTensor {
  size_t subgraph_index;  // index of the subgraph.
  int32_t input_index;    // index of the input tensor.
  int32_t op_index;       // index of the op.
  int32_t output_index;   // index of the output tensor.
  int32_t model_index;    // index of the added tensor in the model.
};

// Finds float tensors that are model inputs and are consumed by a quantize Op.
// The returned TensorOpTensors are in reverse order.
std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
                                            ModelT* model,
                                            ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all input tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    absl::flat_hash_map<TensorT*, int> input_tensors;
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->inputs[input_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        input_tensors.insert({tensor, input_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* input_tensor = subgraph->tensors[op->inputs[0]].get();
      if (input_tensors.find(input_tensor) == input_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_QUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_input_index = input_tensors[input_tensor];
      TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
      if (quant_output->type != TensorType_INT8 &&
          quant_output->type != TensorType_INT16) {
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
      }

      // The input type must be the same as the model quantization type.
      if (input_type != quant_output->type) {
        // An exception, allow for UINT8 input type for INT8 quantized model.
        if (!(input_type == TensorType_UINT8 &&
              quant_output->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s input type is incompatible with %s quantized models. "
              "To resolve this error, change the input_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(input_type),
              EnumNameTensorType(quant_output->type));
        }
      }
      if (quant_output->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_input_index});
    }
  }
  return result;
}

// Finds float tensors that are model outputs and are produced by a dequantize
// Op. The returned TensorOpTensors are in reverse order.
std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
                                             ModelT* model,
                                             ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all output tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    absl::flat_hash_map<TensorT*, int> output_tensors;
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->outputs[output_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        output_tensors.insert({tensor, output_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get();
      if (output_tensors.find(output_tensor) == output_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_DEQUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_output_index = output_tensors[output_tensor];
      TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
      if (dequant_input->type != TensorType_INT8 &&
          dequant_input->type != TensorType_INT16) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
        return {};
      }
      if (output_type != dequant_input->type) {
        // An exception, allow for UINT8 output type for INT8 quantized model.
        if (!(output_type == TensorType_UINT8 &&
              dequant_input->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s output type is incompatible with %s quantized models. "
              "To resolve this error, change the output_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(output_type),
              EnumNameTensorType(dequant_input->type));
        }
      }
      if (dequant_input->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_output_index});
    }
  }
  return result;
}

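// Rewrites each float input tensor in `inputs` as a uint8 tensor: the scale is
// copied from the downstream quantize op's output and the int8 zero point is
// shifted by +128 into the uint8 range, so the existing quantize op now
// consumes uint8 and produces int8.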
TfLiteStatus SetInputTypeToUINT8(ModelT* model,
                                 const std::vector<TensorOpTensor>& inputs) {
  // If the input type is uint8, change float to uint8.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.output_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.input_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = std::make_unique<QuantizationParametersT>();
    }
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);
  }
  return kTfLiteOk;
}

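// Rewrites each float output tensor in `outputs` as a uint8 tensor (same scale
// as the dequantize op's input, zero point shifted by +128) and repurposes the
// trailing dequantize op as a quantize op (int8 in, uint8 out) by pointing its
// opcode_index at the QUANTIZE operator code.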
TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
                                  const std::vector<TensorOpTensor>& outputs) {
  // Find Quant op code index.
  size_t quant_op_index = 0;
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    if (GetBuiltinCode(model->operator_codes[i].get()) ==
        BuiltinOperator_QUANTIZE) {
      quant_op_index = i;
    }
  }
  // If the output type is uint8, change float to uint8.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.input_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.output_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = std::make_unique<QuantizationParametersT>();
    }
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);

    // Change op from dequant (int8 to float) to quant (int8 to uint8)
    OperatorT* op = subgraph->operators[tot.op_index].get();
    op->opcode_index = quant_op_index;
  }
  return kTfLiteOk;
}

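// Removes the leading quantize op of each entry and rewires the corresponding
// subgraph input to the op's (quantized) output tensor. The float tensor is
// erased only if it was appended during quantization (index >=
// original_number_tensors). `inputs` must be sorted in descending index order
// so erasing from the vectors does not invalidate later indices.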
TfLiteStatus RemoveInputTensor(ModelT* model,
                               const std::vector<TensorOpTensor>& inputs,
                               int32 original_number_tensors) {
  // Consistency check to make sure that erasing starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : inputs) {
    TFLITE_DCHECK(tot.input_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.input_index;
    last_op_index = tot.op_index;
  }
  // Removes the input tensor and the related operator.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.input_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->inputs[tot.model_index] = tot.output_index;
  }
  return kTfLiteOk;
}

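// Mirror of RemoveInputTensor for outputs: removes the trailing dequantize op
// of each entry and rewires the corresponding subgraph output to the op's
// (quantized) input tensor. `outputs` must likewise be sorted in descending
// index order.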
TfLiteStatus RemoveOutputTensor(ModelT* model,
                                const std::vector<TensorOpTensor>& outputs,
                                int32 original_number_tensors) {
  // Consistency check to make sure that erasing starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : outputs) {
    TFLITE_DCHECK(tot.output_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.output_index;
    last_op_index = tot.op_index;
  }
  // Removes the output tensor and the related operator.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.output_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->outputs[tot.model_index] = tot.input_index;
  }
  return kTfLiteOk;
}

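// Estimates the tensor count of the first subgraph before the float interface
// tensors were added: total tensors minus the interface input and output
// tensors found above. RemoveInputTensor/RemoveOutputTensor use this threshold
// to decide which float tensors were appended during quantization.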
int GetOriginalNumberOfTensors(const TensorType& input_type,
                               const TensorType& output_type, ModelT* model,
                               ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, error_reporter);
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, error_reporter);
  return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
}

}  // namespace

TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                  ModelT* model, const TensorType& input_type,
                                  const TensorType& output_type) {
  tflite::StderrReporter error_reporter;
  const int original_number_tensors = GetOriginalNumberOfTensors(
      input_type, output_type, model, &error_reporter);
  // Finds float tensors that are model outputs and are produced by an
  // int8/int16 to float dequantize Op. Do outputs first, since the tensors are
  // added to the inputs first.
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, &error_reporter);
  switch (output_type) {
    case TensorType_UINT8:
      SetOutputTypeToUINT8(model, outputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveOutputTensor(model, outputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Finds float tensors that are model inputs and are consumed by a float to
  // int8/int16 quantize Op.
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, &error_reporter);
  switch (input_type) {
    case TensorType_UINT8:
      SetInputTypeToUINT8(model, inputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveInputTensor(model, inputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Write to builder.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model);
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

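// File-based convenience overload: validates the requested interface types,
// loads the model from `input_file`, rewrites its interface, and writes the
// result to `output_file`.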
TfLiteStatus ModifyModelInterface(const string& input_file,
                                  const string& output_file,
                                  const TensorType& input_type,
                                  const TensorType& output_type) {
  // Consistency Check
  if (input_type != tflite::TensorType_INT8 &&
      input_type != tflite::TensorType_UINT8 &&
      input_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }
  if (output_type != tflite::TensorType_INT8 &&
      output_type != tflite::TensorType_UINT8 &&
      output_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }

  // Create model.
  auto tflite_model = utils::CreateMutableModelFromFile(input_file);

  auto model_builder = utils::FinishModel(tflite_model.get());

  auto fixed_point_model_builder =
      std::make_unique<flatbuffers::FlatBufferBuilder>();
  flatbuffers::FlatBufferBuilder builder;

  auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
                                     output_type);
  TFLITE_DCHECK_EQ(status, kTfLiteOk);

  utils::WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());

  return kTfLiteOk;
}

namespace {
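// For each float model input whose name appears in `quant_params`, appends a
// uint8 tensor with the provided (scale, zero_point) and inserts a leading
// Dequantize op at the front of the subgraph that converts that uint8 tensor
// into the original float input tensor.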
void AddUint8Dequant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add dequant to input tensors.
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      const int32_t tensor_idx = subgraph->inputs[input_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add the uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> leading_op_input;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &leading_op_input);
        const int32_t leading_op_input_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(leading_op_input));

        // Create the leading op, which is a Dequantize op.
        std::unique_ptr<OperatorT> leading_op;
        utils::MakeDequantizeOperator(model, &leading_op, leading_op_input_idx,
                                      tensor_idx);

        // Insert the new op at the start of the model.
        subgraph->operators.insert(subgraph->operators.begin(),
                                   std::move(leading_op));
      }
    }
  }
}

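// For each float model output whose name appears in `quant_params`, appends a
// uint8 tensor with the provided (scale, zero_point) and appends a trailing
// Quantize op at the end of the subgraph that converts the original float
// output tensor into that uint8 tensor.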
void AddUint8Quant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add quant to output tensors.
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      const int32_t tensor_idx = subgraph->outputs[output_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add the uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> tailing_op_output;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &tailing_op_output);
        const int32_t tailing_op_output_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(tailing_op_output));

        // Create the tailing op, which is a Quantize op.
        std::unique_ptr<OperatorT> tailing_op;
        utils::MakeQuantizeOperator(model, &tailing_op, tensor_idx,
                                    tailing_op_output_idx);

        // Insert the new op at the end of the model.
        subgraph->operators.push_back(std::move(tailing_op));
      }
    }
  }
}
}  // namespace

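// Unpacks `input_model`, wraps its float inputs and outputs with uint8
// Dequantize/Quantize interface ops using the supplied per-tensor quantization
// parameters, and serializes the modified model into `builder`.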
TfLiteStatus Uint8QuantizeModelInputsOutputs(
    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        input_quant_params,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        output_quant_params) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());
  // Add Dequant for inputs.
  AddUint8Dequant(input_quant_params, model.get());

  // Add Quant for outputs.
  AddUint8Quant(output_quant_params, model.get());

  // Output model.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

}  // namespace optimize
}  // namespace tflite