• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/model_utils.h"
16 
17 #include <memory>
18 
19 #include "absl/memory/memory.h"
20 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
21 #include "tensorflow/lite/kernels/internal/types.h"
22 #include "tensorflow/lite/model.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 #include "tensorflow/lite/tools/optimize/operator_property.h"
25 
26 namespace tflite {
27 namespace optimize {
28 namespace utils {
29 
30 namespace {
31 
32 // Returns the index of the OpCode.
33 // If a OpCode doesn't exist, adds it and returns its index.
GetOrInsertOpCodeIndex(ModelT * model,const BuiltinOperator & op_code,int32_t version)34 int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code,
35                                int32_t version) {
36   for (size_t i = 0; i < model->operator_codes.size(); ++i) {
37     if (model->operator_codes[i]->builtin_code == op_code) {
38       return i;
39     }
40   }
41   model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
42   int op_code_idx = model->operator_codes.size() - 1;
43   model->operator_codes[op_code_idx]->builtin_code = op_code;
44   // Version 2 and onwards supports INT8 inputs.
45   model->operator_codes[op_code_idx]->version = version;
46 
47   // Return the index of the newly placed OperatorCodeT.
48   return op_code_idx;
49 }
50 
51 }  // namespace
52 
53 // Creates a Dequantize OperatorT object.
MakeDequantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)54 void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
55                             int32_t input, int32_t output) {
56   OperatorT* op_raw = new OperatorT;
57   // Version 2 and onwards supports INT8 inputs.
58   op_raw->opcode_index =
59       GetOrInsertOpCodeIndex(model, BuiltinOperator_DEQUANTIZE, 2);
60   op_raw->inputs = {input};
61   op_raw->outputs = {output};
62 
63   op->reset(op_raw);
64 }
65 
66 // Creates a Quantize OperatorT object.
MakeQuantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)67 void MakeQuantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
68                           int32_t input, int32_t output) {
69   OperatorT* op_raw = new OperatorT;
70   op_raw->opcode_index =
71       GetOrInsertOpCodeIndex(model, BuiltinOperator_QUANTIZE, 1);
72   op_raw->inputs = {input};
73   op_raw->outputs = {output};
74 
75   op->reset(op_raw);
76 }
77 
78 // Create a new TensorT object without quantization parameters.
MakeTensor(const string & name,const std::vector<int32_t> & shape,const TensorType & type,std::unique_ptr<TensorT> * tensor)79 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
80                 const TensorType& type, std::unique_ptr<TensorT>* tensor) {
81   TensorT* tensor_raw = new TensorT;
82   tensor_raw->name = name;
83   tensor_raw->shape = shape;
84   tensor_raw->type = type;
85 
86   tensor->reset(tensor_raw);
87 }
88 
89 // Create a new TensorT object with quantization parameters.
MakeTensorWithQuantParam(const string & name,const std::vector<int32_t> & shape,const TensorType & type,float scale,int64_t zero_point,std::unique_ptr<TensorT> * tensor)90 void MakeTensorWithQuantParam(const string& name,
91                               const std::vector<int32_t>& shape,
92                               const TensorType& type, float scale,
93                               int64_t zero_point,
94                               std::unique_ptr<TensorT>* tensor) {
95   MakeTensor(name, shape, type, tensor);
96   (*tensor)->quantization = absl::make_unique<QuantizationParametersT>();
97   (*tensor)->quantization->scale.push_back(scale);
98   (*tensor)->quantization->zero_point.push_back(zero_point);
99 }
100 
QuantizationParametersExist(const TensorT * tensor)101 bool QuantizationParametersExist(const TensorT* tensor) {
102   return tensor->quantization != nullptr &&
103          !tensor->quantization->scale.empty() &&
104          !tensor->quantization->zero_point.empty();
105 }
106 
HasBuffer(const ModelT * model,const SubGraphT * subgraph,int tensor_index)107 bool HasBuffer(const ModelT* model, const SubGraphT* subgraph,
108                int tensor_index) {
109   const int buffer_index = subgraph->tensors[tensor_index]->buffer;
110   BufferT* buffer = model->buffers[buffer_index].get();
111   if (buffer == nullptr || buffer->data.empty()) {
112     return false;
113   }
114   return true;
115 }
116 
HasMinMax(const TensorT * tensor)117 bool HasMinMax(const TensorT* tensor) {
118   return tensor->quantization && !tensor->quantization->min.empty() &&
119          !tensor->quantization->max.empty();
120 }
121 
SetOperatorCodeVersion(ModelT * model)122 void SetOperatorCodeVersion(ModelT* model) {
123   for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
124        subgraph_idx++) {
125     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
126     // Iterate backward to avoid messing with index.
127     for (int op_idx = subgraph->operators.size() - 1; op_idx >= 0; op_idx--) {
128       OperatorT* op = subgraph->operators[op_idx].get();
129       OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
130       operator_property::OperatorProperty property =
131           operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
132       if (property.quantizable) {
133         // Only update the versions of quantizable operations.
134         op_code->version = property.version;
135       }
136     }
137   }
138 }
139 
140 }  // namespace utils
141 }  // namespace optimize
142 }  // namespace tflite
143