/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"

#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"

namespace tflite {
namespace optimize {

namespace {

// Structure to hold an input tensor, an op, and an output tensor.
// The op must be either a Quantize or a Dequantize op.
struct TensorOpTensor {
  size_t subgraph_index;  // index of the subgraph.
  int32_t input_index;    // index of the input tensor.
  int32_t op_index;       // index of the op.
  int32_t output_index;   // index of the output tensor.
  int32_t model_index;    // position of the tensor in the subgraph's
                          // inputs/outputs list.
};
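
// For example, on the input side a TensorOpTensor describes the pattern
//   tensors[input_index] (float32) -> operators[op_index] (QUANTIZE)
//       -> tensors[output_index] (int8/int16),
// where model_index is the position of the float tensor in subgraph->inputs.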

// Finds float tensors that are model inputs and are consumed by a Quantize op.
// The returned TensorOpTensors are in reverse order of op index.
std::vector<TensorOpTensor> GetInputTensors(const TensorType& input_type,
                                            ModelT* model,
                                            ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all input tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    std::unordered_map<TensorT*, int> input_tensors;
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->inputs[input_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        input_tensors.insert({tensor, input_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* input_tensor = subgraph->tensors[op->inputs[0]].get();
      if (input_tensors.find(input_tensor) == input_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_QUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_input_index = input_tensors[input_tensor];
      TensorT* quant_output = subgraph->tensors[op->outputs[0]].get();
      if (quant_output->type != TensorType_INT8 &&
          quant_output->type != TensorType_INT16) {
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
        return {};
      }

      // The input type must be the same as the model quantization type.
      if (input_type != quant_output->type) {
        // An exception: allow a UINT8 input type for an INT8 quantized model.
        if (!(input_type == TensorType_UINT8 &&
              quant_output->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s input type is incompatible with %s quantized models. "
              "To resolve this error, change the input_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(input_type),
              EnumNameTensorType(quant_output->type));
        }
      }
      if (quant_output->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_input_index});
    }
  }
  return result;
}

// Finds float tensors that are model outputs and are produced by a Dequantize
// op.
// The returned TensorOpTensors are in reverse order of op index.
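// Here the pattern is
//   tensors[input_index] (int8/int16) -> operators[op_index] (DEQUANTIZE)
//       -> tensors[output_index] (float32),
// where model_index is the position of the float tensor in subgraph->outputs.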
std::vector<TensorOpTensor> GetOutputTensors(const TensorType& output_type,
                                             ModelT* model,
                                             ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> result;
  // Get all output tensors.
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    std::unordered_map<TensorT*, int> output_tensors;
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      TensorT* tensor = subgraph->tensors[subgraph->outputs[output_idx]].get();
      if (tensor->type == TensorType_FLOAT32) {
        output_tensors.insert({tensor, output_idx});
      }
    }

    for (int32_t op_idx = subgraph->operators.size() - 1; op_idx >= 0;
         op_idx--) {
      OperatorT* op = subgraph->operators[op_idx].get();
      const BuiltinOperator op_code =
          GetBuiltinCode(model->operator_codes[op->opcode_index].get());
      TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get();
      if (output_tensors.find(output_tensor) == output_tensors.end()) {
        continue;
      }
      if (op_code != BuiltinOperator_DEQUANTIZE) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(
            error_reporter,
            "modify_model_interface called on a model without quant/dequant.");
        return {};
      }
      if (op->inputs.size() != 1) {
        continue;
      }
      if (op->outputs.size() != 1) {
        continue;
      }
      const int model_output_index = output_tensors[output_tensor];
      TensorT* dequant_input = subgraph->tensors[op->inputs[0]].get();
      if (dequant_input->type != TensorType_INT8 &&
          dequant_input->type != TensorType_INT16) {
        // Currently only supports int8 and int16 quantized models.
        TF_LITE_REPORT_ERROR(error_reporter,
                             "modify_model_interface currently only supports "
                             "int8 and int16 quantized models.");
        return {};
      }
      if (output_type != dequant_input->type) {
        // An exception: allow a UINT8 output type for an INT8 quantized model.
        if (!(output_type == TensorType_UINT8 &&
              dequant_input->type == TensorType_INT8)) {
          TF_LITE_REPORT_ERROR(
              error_reporter,
              "The %s output type is incompatible with %s quantized models. "
              "To resolve this error, change the output_type to a compatible "
              "one. "
              "See: modify_model_interface.cc",
              EnumNameTensorType(output_type),
              EnumNameTensorType(dequant_input->type));
        }
      }
      if (dequant_input->quantization == nullptr) {
        continue;
      }
      result.push_back({subgraph_idx, op->inputs[0], op_idx, op->outputs[0],
                        model_output_index});
    }
  }
  return result;
}

TfLiteStatus SetInputTypeToUINT8(ModelT* model,
                                 const std::vector<TensorOpTensor>& inputs) {
  // If the input type is uint8, change float to uint8.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.output_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.input_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = absl::make_unique<QuantizationParametersT>();
    }
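    // Reuse the Quantize op's scale; shift the zero point by +128 to map the
    // int8 range [-128, 127] onto the uint8 range [0, 255].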
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);
  }
  return kTfLiteOk;
}

TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
                                  const std::vector<TensorOpTensor>& outputs) {
  // Find Quant op code index.
  size_t quant_op_index = 0;
  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
    if (GetBuiltinCode(model->operator_codes[i].get()) ==
        BuiltinOperator_QUANTIZE) {
      quant_op_index = i;
    }
  }
  // If the output type is uint8, change float to uint8.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TensorT* quant_tensor = subgraph->tensors[tot.input_index].get();
    const float quant_tensor_scale = quant_tensor->quantization->scale[0];
    const int quant_tensor_zp = quant_tensor->quantization->zero_point[0];
    TensorT* float_tensor = subgraph->tensors[tot.output_index].get();
    float_tensor->type = TensorType_UINT8;
    if (float_tensor->quantization == nullptr) {
      float_tensor->quantization = absl::make_unique<QuantizationParametersT>();
    }
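    // As on the input side: same scale, zero point shifted by +128 for the
    // int8 -> uint8 conversion.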
    float_tensor->quantization->scale.push_back(quant_tensor_scale);
    float_tensor->quantization->zero_point.push_back(quant_tensor_zp + 128);

    // Change op from dequant (int8 to float) to quant (int8 to uint8).
    OperatorT* op = subgraph->operators[tot.op_index].get();
    op->opcode_index = quant_op_index;
  }
  return kTfLiteOk;
}

TfLiteStatus RemoveInputTensor(ModelT* model,
                               const std::vector<TensorOpTensor>& inputs,
                               int32 original_number_tensors) {
  // Consistency check to make sure that erasing starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : inputs) {
    TFLITE_DCHECK(tot.input_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.input_index;
    last_op_index = tot.op_index;
  }
  // Removes the input tensor and the related operator.
  for (auto tot : inputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.input_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
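    // Rewire the model input to the Quantize op's former output, so the graph
    // now starts directly at the quantized tensor.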
    subgraph->inputs[tot.model_index] = tot.output_index;
  }
  return kTfLiteOk;
}

TfLiteStatus RemoveOutputTensor(ModelT* model,
                                const std::vector<TensorOpTensor>& outputs,
                                int32 original_number_tensors) {
  // Consistency check to make sure that erasing starts from the end.
  int last_op_index = std::numeric_limits<int32_t>::max();
  int last_tensor_index = std::numeric_limits<int32_t>::max();
  for (auto tot : outputs) {
    TFLITE_DCHECK(tot.output_index < last_tensor_index);
    TFLITE_DCHECK(tot.op_index < last_op_index);
    last_tensor_index = tot.output_index;
    last_op_index = tot.op_index;
  }
  // Removes the output tensor and the related operator.
  for (auto tot : outputs) {
    SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
    TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
    TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
    if (tot.output_index >= original_number_tensors) {
      subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
    }
    subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
    subgraph->outputs[tot.model_index] = tot.input_index;
  }
  return kTfLiteOk;
}

int GetOriginalNumberOfTensors(const TensorType& input_type,
                               const TensorType& output_type, ModelT* model,
                               ErrorReporter* error_reporter) {
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, error_reporter);
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, error_reporter);
  return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
}

}  // namespace

TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                  ModelT* model, const TensorType& input_type,
                                  const TensorType& output_type) {
  tflite::StderrReporter error_reporter;
  const int original_number_tensors = GetOriginalNumberOfTensors(
      input_type, output_type, model, &error_reporter);
  // Find float tensors that are model outputs and are produced by an
  // int8/int16-to-float Dequantize op. Do outputs first since the tensors
  // are added to the inputs first.
  std::vector<TensorOpTensor> outputs =
      GetOutputTensors(output_type, model, &error_reporter);
  switch (output_type) {
    case TensorType_UINT8:
      SetOutputTypeToUINT8(model, outputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveOutputTensor(model, outputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Find float tensors that are model inputs and are consumed by a
  // float-to-int8/int16 Quantize op.
  std::vector<TensorOpTensor> inputs =
      GetInputTensors(input_type, model, &error_reporter);
  switch (input_type) {
    case TensorType_UINT8:
      SetInputTypeToUINT8(model, inputs);
      break;
    case TensorType_INT8:
    case TensorType_INT16:
      RemoveInputTensor(model, inputs, original_number_tensors);
      break;
    default:
      return kTfLiteError;
  }

  // Write to builder.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model);
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

TfLiteStatus ModifyModelInterface(const string& input_file,
                                  const string& output_file,
                                  const TensorType& input_type,
                                  const TensorType& output_type) {
  // Consistency check.
  if (input_type != tflite::TensorType_INT8 &&
      input_type != tflite::TensorType_UINT8 &&
      input_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }
  if (output_type != tflite::TensorType_INT8 &&
      output_type != tflite::TensorType_UINT8 &&
      output_type != tflite::TensorType_INT16) {
    return kTfLiteError;
  }

  // Create model.
  auto tflite_model = utils::CreateMutableModelFromFile(input_file);

  auto model_builder = utils::FinishModel(tflite_model.get());

  auto fixed_point_model_builder =
      absl::make_unique<flatbuffers::FlatBufferBuilder>();
  flatbuffers::FlatBufferBuilder builder;

  auto status = ModifyModelInterface(&builder, tflite_model.get(), input_type,
                                     output_type);
  TFLITE_DCHECK_EQ(status, kTfLiteOk);

  utils::WriteFile(output_file, builder.GetBufferPointer(), builder.GetSize());

  return kTfLiteOk;
}
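
// Example usage (a minimal sketch; the file paths are hypothetical):
//
//   TfLiteStatus status = tflite::optimize::ModifyModelInterface(
//       "/path/to/quantized_model.tflite", "/path/to/new_model.tflite",
//       tflite::TensorType_INT8, tflite::TensorType_INT8);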

namespace {
void AddUint8Dequant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add dequant to input tensors.
    for (size_t input_idx = 0; input_idx < subgraph->inputs.size();
         input_idx++) {
      const int32_t tensor_idx = subgraph->inputs[input_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add a uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> leading_op_input;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &leading_op_input);
        const int32_t leading_op_input_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(leading_op_input));

        // Create the leading op, which is a Dequantize op.
        std::unique_ptr<OperatorT> leading_op;
        utils::MakeDequantizeOperator(model, &leading_op, leading_op_input_idx,
                                      tensor_idx);

        // Insert the new op at the start of the model.
        subgraph->operators.insert(subgraph->operators.begin(),
                                   std::move(leading_op));
      }
    }
  }
}

void AddUint8Quant(
    const std::unordered_map<string, std::pair<float, int32_t>>& quant_params,
    ModelT* model) {
  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
       subgraph_idx++) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
    // Add quant to output tensors.
    for (size_t output_idx = 0; output_idx < subgraph->outputs.size();
         output_idx++) {
      const int32_t tensor_idx = subgraph->outputs[output_idx];
      TensorT* tensor = subgraph->tensors[tensor_idx].get();
      if (tensor->type != TensorType_FLOAT32) {
        continue;
      }
      if (quant_params.find(tensor->name) != quant_params.end()) {
        // Add a uint8 tensor.
        const string added_tensor_name = tensor->name + "_uint8";
        std::unique_ptr<TensorT> tailing_op_output;
        const std::pair<float, int32_t>& provided_quant_params =
            quant_params.at(string(tensor->name));
        utils::MakeTensorWithQuantParam(
            added_tensor_name, tensor->shape, tensor->shape_signature,
            TensorType_UINT8, provided_quant_params.first,
            provided_quant_params.second, &tailing_op_output);
        const int32_t tailing_op_output_idx = subgraph->tensors.size();
        subgraph->tensors.push_back(std::move(tailing_op_output));

        // Create the tailing op, which is a Quantize op.
        std::unique_ptr<OperatorT> tailing_op;
        utils::MakeQuantizeOperator(model, &tailing_op, tensor_idx,
                                    tailing_op_output_idx);

        // Insert the new op at the end of the model.
        subgraph->operators.push_back(std::move(tailing_op));
      }
    }
  }
}
}  // namespace

TfLiteStatus Uint8QuantizeModelInputsOutputs(
    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        input_quant_params,
    const std::unordered_map<string, std::pair<float, int32_t>>&
        output_quant_params) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());
  // Add Dequant for inputs.
  AddUint8Dequant(input_quant_params, model.get());

  // Add Quant for outputs.
  AddUint8Quant(output_quant_params, model.get());

  // Output model.
  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}
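
// Example usage (a minimal sketch; the tensor names and quantization
// parameters are hypothetical, the pairs are assumed to be
// {scale, zero_point}, and `float_model` is a `const Model*` loaded
// elsewhere, e.g. via GetModel() on a flatbuffer read from disk):
//
//   flatbuffers::FlatBufferBuilder builder;
//   Uint8QuantizeModelInputsOutputs(
//       &builder, float_model,
//       /*input_quant_params=*/{{"input", {1.0f / 255.0f, 0}}},
//       /*output_quant_params=*/{{"output", {1.0f / 256.0f, 0}}});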

}  // namespace optimize
}  // namespace tflite