• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/tools/optimize/model_utils.h"

#include <fstream>
#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_conversion_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"
28 
29 namespace tflite {
30 namespace optimize {
31 namespace utils {
32 
33 namespace {
34 
35 // Returns the index of the OpCode.
36 // If a OpCode doesn't exist, adds it and returns its index.
GetOrInsertOpCodeIndex(ModelT * model,const BuiltinOperator & op_code,int32_t version)37 int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code,
38                                int32_t version) {
39   for (size_t i = 0; i < model->operator_codes.size(); ++i) {
40     if (GetBuiltinCode(model->operator_codes[i].get()) == op_code) {
41       return i;
42     }
43   }
44   model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
45   int op_code_idx = model->operator_codes.size() - 1;
46   model->operator_codes[op_code_idx]->builtin_code = op_code;
47   model->operator_codes[op_code_idx]->deprecated_builtin_code =
48       ConvertBuiltinCodeToDeprecatedBuiltinCode(op_code);
49   // Version 2 and onwards supports INT8 inputs.
50   model->operator_codes[op_code_idx]->version = version;
51 
52   // Return the index of the newly placed OperatorCodeT.
53   return op_code_idx;
54 }
55 
56 }  // namespace
57 
58 // Creates a Dequantize OperatorT object.
MakeDequantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)59 void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
60                             int32_t input, int32_t output) {
61   OperatorT* op_raw = new OperatorT;
62   // Version 2 and onwards supports INT8 inputs.
63   op_raw->opcode_index =
64       GetOrInsertOpCodeIndex(model, BuiltinOperator_DEQUANTIZE, 2);
65   op_raw->inputs = {input};
66   op_raw->outputs = {output};
67 
68   op->reset(op_raw);
69 }
70 
71 // Creates a Quantize OperatorT object.
MakeQuantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)72 void MakeQuantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
73                           int32_t input, int32_t output) {
74   OperatorT* op_raw = new OperatorT;
75   op_raw->opcode_index =
76       GetOrInsertOpCodeIndex(model, BuiltinOperator_QUANTIZE, 1);
77   op_raw->inputs = {input};
78   op_raw->outputs = {output};
79 
80   op->reset(op_raw);
81 }
82 
83 // Create a new TensorT object without quantization parameters.
MakeTensor(const string & name,const std::vector<int32_t> & shape,const std::vector<int32_t> & shape_signature,const TensorType & type,std::unique_ptr<TensorT> * tensor)84 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
85                 const std::vector<int32_t>& shape_signature,
86                 const TensorType& type, std::unique_ptr<TensorT>* tensor) {
87   TensorT* tensor_raw = new TensorT;
88   tensor_raw->name = name;
89   tensor_raw->shape = shape;
90   if (!shape_signature.empty()) {
91     tensor_raw->shape_signature = shape_signature;
92   }
93   tensor_raw->type = type;
94 
95   tensor->reset(tensor_raw);
96 }
97 
98 // Create a new TensorT object with quantization parameters.
MakeTensorWithQuantParam(const string & name,const std::vector<int32_t> & shape,const std::vector<int32_t> & shape_signature,const TensorType & type,float scale,int64_t zero_point,std::unique_ptr<TensorT> * tensor)99 void MakeTensorWithQuantParam(const string& name,
100                               const std::vector<int32_t>& shape,
101                               const std::vector<int32_t>& shape_signature,
102                               const TensorType& type, float scale,
103                               int64_t zero_point,
104                               std::unique_ptr<TensorT>* tensor) {
105   MakeTensor(name, shape, shape_signature, type, tensor);
106   (*tensor)->quantization = absl::make_unique<QuantizationParametersT>();
107   (*tensor)->quantization->scale.push_back(scale);
108   (*tensor)->quantization->zero_point.push_back(zero_point);
109 }
110 
QuantizationParametersExist(const TensorT * tensor)111 bool QuantizationParametersExist(const TensorT* tensor) {
112   return tensor->quantization != nullptr &&
113          !tensor->quantization->scale.empty() &&
114          !tensor->quantization->zero_point.empty();
115 }
116 
HasBuffer(const ModelT * model,const SubGraphT * subgraph,int tensor_index)117 bool HasBuffer(const ModelT* model, const SubGraphT* subgraph,
118                int tensor_index) {
119   const int buffer_index = subgraph->tensors[tensor_index]->buffer;
120   BufferT* buffer = model->buffers[buffer_index].get();
121   if (buffer == nullptr || buffer->data.empty()) {
122     return false;
123   }
124   return true;
125 }
126 
HasMinMax(const TensorT * tensor)127 bool HasMinMax(const TensorT* tensor) {
128   return tensor->quantization && !tensor->quantization->min.empty() &&
129          !tensor->quantization->max.empty();
130 }
131 
SetOperatorCodeVersion(ModelT * model)132 void SetOperatorCodeVersion(ModelT* model) {
133   for (int subgraph_idx = 0, end = model->subgraphs.size(); subgraph_idx < end;
134        subgraph_idx++) {
135     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
136     // Iterate backward to avoid messing with index.
137     for (int op_idx = subgraph->operators.size() - 1; op_idx >= 0; op_idx--) {
138       OperatorT* op = subgraph->operators[op_idx].get();
139       OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
140       operator_property::OperatorProperty property =
141           operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
142       if (property.quantizable && op_code->version < property.version) {
143         // Only update the versions of quantizable operations if the original
144         // version is lesser than minimum quantized one mentioned by
145         // OperatorProperty.
146         op_code->version = property.version;
147       }
148     }
149   }
150 }
151 
WriteFile(const std::string & out_file,const uint8_t * bytes,size_t num_bytes)152 void WriteFile(const std::string& out_file, const uint8_t* bytes,
153                size_t num_bytes) {
154   std::fstream stream(out_file, std::ios::binary | std::ios::out);
155   for (size_t i = 0; i < num_bytes; i++) {
156     stream << bytes[i];
157   }
158   TFLITE_DCHECK(!stream.bad() && !stream.fail());
159 }
160 
FinishModel(const tflite::ModelT * model)161 std::unique_ptr<flatbuffers::FlatBufferBuilder> FinishModel(
162     const tflite::ModelT* model) {
163   std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
164       new flatbuffers::FlatBufferBuilder());
165   auto packed_model = tflite::Model::Pack(*builder, model);
166   tflite::FinishModelBuffer(*builder, packed_model);
167   return builder;
168 }
169 
CreateMutableModelFromFile(const string & model_filepath)170 std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
171     const string& model_filepath) {
172   auto fb_model =
173       tflite::FlatBufferModel::BuildFromFile(model_filepath.c_str());
174   auto tflite_model = fb_model->GetModel();
175   auto copied_model = absl::make_unique<tflite::ModelT>();
176   tflite_model->UnPackTo(copied_model.get(), nullptr);
177   return copied_model;
178 }
179 
180 }  // namespace utils
181 }  // namespace optimize
182 }  // namespace tflite
183