/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/test_util.h"

namespace {
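// Directory holding the test models; set in main() from the --test_model_file
// flag before the tests run.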
tensorflow::string* g_test_model_dir = nullptr;
}  // namespace

namespace tflite {
namespace optimize {
namespace {

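// Helpers that load the various test models (file names come from
// tools/optimize/test_util.h) out of the directory supplied on the command
// line.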
std::unique_ptr<FlatBufferModel> ReadTestModel() {
  auto model_path = tensorflow::io::JoinPath(
      *g_test_model_dir, internal::kConvModelWith0Plus10Weights);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
                                             internal::kModelWithSharedWeights);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
                                             internal::kQuantizedWithGather);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
  auto model_path =
      tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

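// Copies a flatbuffer vector into a std::vector so it can be compared with
// EXPECT_EQ.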
template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
  return std::vector<T>(vec->begin(), vec->end());
}

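// Test fixture that loads one of the test models and provides helpers for
// inspecting the quantized output model.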
class QuantizeWeightsTest : public testing::Test {
 protected:
  QuantizeWeightsTest() {}

  void LoadBasicModel() {
    input_model_ = ReadTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadSharedWeightsModel() {
    input_model_ = ReadSharedWeightsTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadGatherTestModel() {
    input_model_ = ReadGatherTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadCustomOpTestModel() {
    input_model_ = ReadCustomOpTestModel();
    model_ = input_model_->GetModel();
  }

  std::unique_ptr<FlatBufferModel> input_model_;
  const Model* model_;

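  // Returns true if tensor_idx is listed as an input or output of any
  // subgraph in the given model.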
  bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) {
    for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
         ++subgraph_idx) {
      const auto subgraph = model->subgraphs()->Get(subgraph_idx);
      for (size_t i = 0; i < subgraph->inputs()->size(); ++i) {
        if (subgraph->inputs()->Get(i) == tensor_idx) {
          return true;
        }
      }
      for (size_t i = 0; i < subgraph->outputs()->size(); ++i) {
        if (subgraph->outputs()->Get(i) == tensor_idx) {
          return true;
        }
      }
    }
    return false;
  }

  // Looks up the builtin op code of the operator that produces tensor_idx in
  // the given subgraph. Returns false if no producer is found.
  bool GetProducerOpCode(const Model* model, uint32_t subgraph_idx,
                         uint32_t tensor_idx,
                         tflite::BuiltinOperator* op_code) {
    const auto subgraph = model->subgraphs()->Get(subgraph_idx);
    for (size_t op_idx = 0; op_idx < subgraph->operators()->size(); ++op_idx) {
      const auto op = subgraph->operators()->Get(op_idx);
      for (size_t i = 0; i < op->outputs()->size(); ++i) {
        if (op->outputs()->Get(i) == tensor_idx) {
          const uint32_t op_code_idx = op->opcode_index();
          *op_code = GetBuiltinCode(model->operator_codes()->Get(op_code_idx));
          return true;
        }
      }
    }
    return false;
  }
};

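// Quantizing the basic conv model should succeed and produce a readable
// flatbuffer.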
TEST_F(QuantizeWeightsTest, QuantizationSucceeds) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_F(QuantizeWeightsTest, WeightsMinNumElements) {
  LoadBasicModel();
  // Make weights_min_num_elements sufficiently large so that no quantization
  // happens, i.e. the output model is identical to the input model.
  flatbuffers::FlatBufferBuilder builder;
  const uint64_t kWeightsMinNumElements = 1000000;
  EXPECT_EQ(QuantizeWeights(&builder, model_, kWeightsMinNumElements,
                            QuantizerType::OLD_QUANTIZER),
            kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = float_graph->tensors()->Get(i);
      // Everything should remain equal between the two graphs.
      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
                GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
      EXPECT_EQ(quant_tensor->type(), float_tensor->type());
    }
  }
}

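// With the default (hybrid) path, the conv weights should become INT8 while
// the inputs, outputs, and bias stay FLOAT32.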
TEST_F(QuantizeWeightsTest, HybridConv) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  // The graph structure should not change; only the weight types are updated.
  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    // Make sure the graph only has one Conv operation.
    ASSERT_EQ(quantized_graph->operators()->size(), 1);
    const auto op = quantized_graph->operators()->Get(0);
    const uint32_t op_code_idx = op->opcode_index();
    ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)),
              BuiltinOperator_CONV_2D);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = float_graph->tensors()->Get(i);
      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
                GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
      // Weights should now have type INT8; the bias and all other tensors
      // should keep type FLOAT32.
      if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
            << quant_tensor->name()->str();
        auto shape = GetAsVector(quant_tensor->shape());
        if (kUseUpdatedHybridSchemeDefault) {
          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
        } else {
          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
        }
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

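// With hybrid evaluation disabled, the weights are stored as INT8 and a
// DEQUANTIZE op is inserted to convert them back to FLOAT32 for the conv.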
TEST_F(QuantizeWeightsTest, DequantizeConv) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = internal::QuantizeWeights(&builder, model_, 0,
                                          /*use_hybrid_evaluation=*/false,
                                          QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have an extra tensor from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 1);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // The input to dequantize (the weights) should be INT8; the bias, model
      // inputs/outputs, and all other tensors should be FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

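// With the QUANTIZED_FLOAT16 buffer type, the constants (weights and bias) are
// stored as FLOAT16 and converted back to FLOAT32 through added DEQUANTIZE
// ops.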
TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = tflite::optimize::QuantizeWeights(
      &builder, model_, BufferType::QUANTIZED_FLOAT16,
      kUseUpdatedHybridSchemeDefault, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have two extra tensors from the added dequantize
    // ops.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 2);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // The weights and bias (the dequantize inputs) should be FLOAT16; model
      // inputs/outputs, the dequantize outputs, and all other tensors should
      // be FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

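// When two convolutions share the same weights tensor, hybrid quantization
// should still convert that shared tensor to INT8.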
TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) {
  LoadSharedWeightsModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  uint32_t num_conv_ops = 0;
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_CONV_2D) {
        num_conv_ops++;
        // Ensure that each convolution's weights tensor is now INT8.
        const auto weights_tensor =
            quantized_graph->tensors()->Get(op->inputs()->Get(1));
        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      }
    }
  }
  // Ensure that there were exactly two convolutions in the model.
  EXPECT_EQ(num_conv_ops, 2);
}

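// In dequantize mode, the shared weights feed each convolution through a
// DEQUANTIZE op, so the tensor each conv actually consumes stays FLOAT32.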
TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
  LoadSharedWeightsModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = internal::QuantizeWeights(&builder, model_, 0,
                                          /*use_hybrid_evaluation=*/false,
                                          QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  uint32_t num_conv_ops = 0;
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_CONV_2D) {
        num_conv_ops++;
        // Ensure that each convolution's weights tensor is still FLOAT32
        // (the output of the dequantize).
        uint32_t weights_tensor_index = op->inputs()->Get(1);
        const auto weights_tensor =
            quantized_graph->tensors()->Get(weights_tensor_index);
        EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32);

        // Check that it comes from a dequantize operation.
        BuiltinOperator producer_op_code;
        ASSERT_TRUE(GetProducerOpCode(output_model, subgraph_idx,
                                      weights_tensor_index, &producer_op_code));
        EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE);
      }
    }
  }
  // Ensure that there were exactly two convolutions in the model.
  EXPECT_EQ(num_conv_ops, 2);
}

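// The constant data input of a GATHER op should be quantized to INT8.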
TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
  LoadGatherTestModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_GATHER) {
        uint32_t input_tensor_index = op->inputs()->Get(0);
        const auto weights_tensor =
            quantized_graph->tensors()->Get(input_tensor_index);
        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      }
    }
  }
}

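// A custom op registered as non-hybrid gets its quantizable constant input
// dequantized back to FLOAT32 through an added DEQUANTIZE op.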
TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) {
  LoadCustomOpTestModel();

  // The custom op is not hybrid, and the second input is a constant that can
  // be quantized.
  CustomOpMap custom_op_map;
  custom_op_map["CustomTestOp"] = {
      .quantizable_input_indices = {1},
      .is_hybrid = false,
  };

  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                QuantizerType::OLD_QUANTIZER);
  ASSERT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  const auto quantized_graph = output_model->subgraphs()->Get(0);
  // A dequantize op should be added.
  ASSERT_EQ(quantized_graph->operators()->size(),
            model_->subgraphs()->Get(0)->operators()->size() + 1);
  int num_custom_ops_found = 0;
  for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
    const auto op = quantized_graph->operators()->Get(i);
    const uint32_t op_code_idx = op->opcode_index();
    const auto op_code =
        GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
    if (op_code == BuiltinOperator_CUSTOM) {
      uint32_t weights_tensor_index = op->inputs()->Get(1);
      const auto weights_tensor =
          quantized_graph->tensors()->Get(weights_tensor_index);
      EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32);

      // Check that it comes from a dequantize operation.
      BuiltinOperator producer_op_code;
      ASSERT_TRUE(GetProducerOpCode(output_model, 0, weights_tensor_index,
                                    &producer_op_code));
      EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE);
      num_custom_ops_found++;
    }
  }
  EXPECT_EQ(num_custom_ops_found, 1);
}

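// A custom op registered as hybrid consumes the INT8 constant input directly,
// so no DEQUANTIZE op is added.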
TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) {
  LoadCustomOpTestModel();

  // The custom op is hybrid, and the second input is a constant that can
  // be quantized.
  CustomOpMap custom_op_map;
  custom_op_map["CustomTestOp"] = {
      .quantizable_input_indices = {1},
      .is_hybrid = true,
  };

  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                QuantizerType::OLD_QUANTIZER);
  ASSERT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  const auto quantized_graph = output_model->subgraphs()->Get(0);
  ASSERT_EQ(quantized_graph->operators()->size(),
            model_->subgraphs()->Get(0)->operators()->size());
  int num_custom_ops_found = 0;
  for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
    const auto op = quantized_graph->operators()->Get(i);
    const uint32_t op_code_idx = op->opcode_index();
    const auto op_code =
        GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
    if (op_code == BuiltinOperator_CUSTOM) {
      uint32_t weights_tensor_index = op->inputs()->Get(1);
      const auto weights_tensor =
          quantized_graph->tensors()->Get(weights_tensor_index);
      EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      num_custom_ops_found++;
    }
  }
  EXPECT_EQ(num_custom_ops_found, 1);
}

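// With use_updated_hybrid_scheme=false, the conv weights are quantized with a
// single per-tensor scale instead of per-channel scales.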
TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  const CustomOpMap custom_op_map;
  auto status = QuantizeWeights(
      &builder, model_, 0, custom_op_map, /*use_updated_hybrid_scheme=*/false,
      /*op_denylist=*/{}, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  // The graph structure should not change; only the weight types are updated.
  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    // Make sure the graph only has one Conv operation.
    ASSERT_EQ(quantized_graph->operators()->size(), 1);
    const auto op = quantized_graph->operators()->Get(0);
    const uint32_t op_code_idx = op->opcode_index();
    ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)),
              BuiltinOperator_CONV_2D);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = float_graph->tensors()->Get(i);
      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
                GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
      // Weights should now have type INT8 with a single per-tensor scale; the
      // bias and all other tensors should keep type FLOAT32.
      if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
            << quant_tensor->name()->str();
        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

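// Denylisting CONV_2D forces the dequantize path for that op even though the
// updated hybrid scheme is requested; the weights keep their per-channel
// quantization.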
TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  const CustomOpMap custom_op_map;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                /*use_updated_hybrid_scheme=*/true,
                                /*op_denylist=*/{BuiltinOperator_CONV_2D},
                                QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have an extra tensor from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 1);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // The input to dequantize (the weights) should be INT8; the bias, model
      // inputs/outputs, and all other tensors should be FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
        // The dequantized weights should still be quantized per-channel.
        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 5);
        EXPECT_EQ(quant_tensor->quantization()->quantized_dimension(), 0);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

}  // namespace
}  // namespace optimize
}  // namespace tflite

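// Parses the --test_model_file flag, records the directory that holds the test
// models, and runs all tests.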
int main(int argc, char** argv) {
  tensorflow::string model_file;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("test_model_file", &model_file,
                       "Path to test tflite model file."),
  };

  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    std::cerr << "Required flag --test_model_file was not provided.\n";
    std::abort();
  }
  g_test_model_dir =
      new tensorflow::string(tensorflow::io::Dirname(model_file));
  ::tensorflow::port::InitMain(argv[0], &argc, &argv);
  return RUN_ALL_TESTS();
}