/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

#include <cstddef>
#include <cstdint>
#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/test_util.h"

namespace {
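// Directory containing the test models; set in main() from the directory of
// the --test_model_file flag.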
tensorflow::string* g_test_model_dir = nullptr;
}  // namespace

namespace tflite {
namespace optimize {
namespace {

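// Helpers that load each pre-generated test model from *g_test_model_dir.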
std::unique_ptr<FlatBufferModel> ReadTestModel() {
  auto model_path = tensorflow::io::JoinPath(
      *g_test_model_dir, internal::kConvModelWith0Plus10Weights);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
                                             internal::kModelWithSharedWeights);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
                                             internal::kQuantizedWithGather);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
  auto model_path =
      tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
  return std::vector<T>(vec->begin(), vec->end());
}

class QuantizeWeightsTest : public testing::Test {
 protected:
  QuantizeWeightsTest() {}

  void LoadBasicModel() {
    input_model_ = ReadTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadSharedWeightsModel() {
    input_model_ = ReadSharedWeightsTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadGatherTestModel() {
    input_model_ = ReadGatherTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadCustomOpTestModel() {
    input_model_ = ReadCustomOpTestModel();
    model_ = input_model_->GetModel();
  }

  std::unique_ptr<FlatBufferModel> input_model_;
  const Model* model_;

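  // Returns true if tensor_idx is an input or output of any subgraph in the
  // model.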
  bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) {
    for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
         ++subgraph_idx) {
      const auto subgraph = model->subgraphs()->Get(subgraph_idx);
      for (size_t i = 0; i < subgraph->inputs()->size(); ++i) {
        if (subgraph->inputs()->Get(i) == tensor_idx) {
          return true;
        }
      }
      for (size_t i = 0; i < subgraph->outputs()->size(); ++i) {
        if (subgraph->outputs()->Get(i) == tensor_idx) {
          return true;
        }
      }
    }
    return false;
  }

  // Looks up the builtin op code of the operator that produces tensor_idx in
  // the given subgraph. Returns false if the tensor has no producer.
  bool GetProducerOpCode(const Model* model, uint32_t subgraph_idx,
                         uint32_t tensor_idx,
                         tflite::BuiltinOperator* op_code) {
    const auto subgraph = model->subgraphs()->Get(subgraph_idx);
    for (size_t op_idx = 0; op_idx < subgraph->operators()->size(); ++op_idx) {
      const auto op = subgraph->operators()->Get(op_idx);
      for (size_t i = 0; i < op->outputs()->size(); ++i) {
        if (op->outputs()->Get(i) == tensor_idx) {
          const uint32_t op_code_idx = op->opcode_index();
          *op_code = GetBuiltinCode(model->operator_codes()->Get(op_code_idx));
          return true;
        }
      }
    }
    return false;
  }
};

TEST_F(QuantizeWeightsTest, QuantizationSucceeds) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_F(QuantizeWeightsTest, WeightsMinNumElements) {
  LoadBasicModel();
  // Make weights_min_size sufficiently large that no quantization happens,
  // i.e. the output model is identical to the input model.
  flatbuffers::FlatBufferBuilder builder;
  const uint64_t kWeightsMinNumElements = 1000000;
  EXPECT_EQ(QuantizeWeights(&builder, model_, kWeightsMinNumElements,
                            QuantizerType::OLD_QUANTIZER),
            kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = float_graph->tensors()->Get(i);
      // Everything should remain equal between the two graphs.
      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
                GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
      EXPECT_EQ(quant_tensor->type(), float_tensor->type());
    }
  }
}

TEST_F(QuantizeWeightsTest, HybridConv) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  // Nothing should change.
  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    // Make sure the graph only has one Conv operation.
    ASSERT_EQ(quantized_graph->operators()->size(), 1);
    const auto op = quantized_graph->operators()->Get(0);
    const uint32_t op_code_idx = op->opcode_index();
    ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)),
              BuiltinOperator_CONV_2D);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = float_graph->tensors()->Get(i);
      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
                GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
      // If the tensor is a weight, it should have type INT8, otherwise it
      // should stay with type FLOAT32.
      // If the tensor is a bias, it should have type FLOAT32.
      if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
            << quant_tensor->name()->str();
        auto shape = GetAsVector(quant_tensor->shape());
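        // With the updated hybrid scheme, weights are quantized per output
        // channel, so expect one scale per output channel; the legacy scheme
        // uses a single per-tensor scale.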
        if (kUseUpdatedHybridSchemeDefault) {
          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
        } else {
          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
        }
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, DequantizeConv) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = internal::QuantizeWeights(&builder, model_, 0,
                                          /*use_hybrid_evaluation=*/false,
                                          QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have an extra tensor from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 1);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // If the tensor is a weight, it should have type INT8.
      // If the tensor is a bias, it should have type FLOAT32.
      // If the tensor is an input or output it should have type FLOAT32.
      // The input to dequantize should be INT8, and all other tensors should be
      // FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = tflite::optimize::QuantizeWeights(
      &builder, model_, BufferType::QUANTIZED_FLOAT16,
      kUseUpdatedHybridSchemeDefault, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have two extra tensors from the added dequantize
    // ops.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 2);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // If the tensor is a weight, it should have type FLOAT16.
      // If the tensor is a bias, it should have type FLOAT16.
      // If the tensor is an input or output it should have type FLOAT32.
      // The input to dequantize should be FLOAT16, and all other tensors should
      // be FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

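// The shared-weights model contains two convolutions; after hybrid
// quantization the weights feeding each convolution should be INT8.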
TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) {
  LoadSharedWeightsModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  uint32_t num_conv_ops = 0;
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_CONV_2D) {
        num_conv_ops++;
        // Ensure that each convolution's weights tensor is now INT8.
        const auto weights_tensor =
            quantized_graph->tensors()->Get(op->inputs()->Get(1));
        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      }
    }
  }
  // Ensure that there were exactly two convolutions in the model.
  EXPECT_EQ(num_conv_ops, 2);
}

TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
  LoadSharedWeightsModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = internal::QuantizeWeights(&builder, model_, 0,
                                          /*use_hybrid_evaluation=*/false,
                                          QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  uint32_t num_conv_ops = 0;
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_CONV_2D) {
        num_conv_ops++;
        // Ensure that each convolution's weights tensor is still FLOAT
        // (the output of the dequantize).
        uint32_t weights_tensor_index = op->inputs()->Get(1);
        const auto weights_tensor =
            quantized_graph->tensors()->Get(weights_tensor_index);
        EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32);

        // Check that it comes from a dequantize operation.
        BuiltinOperator producer_op_code;
        ASSERT_TRUE(GetProducerOpCode(output_model, subgraph_idx,
                                      weights_tensor_index, &producer_op_code));
        EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE);
      }
    }
  }
  // Ensure that there were exactly two convolutions in the model.
  EXPECT_EQ(num_conv_ops, 2);
}

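// Gather reads its looked-up values from its first input; that constant
// tensor should be quantized to INT8.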
TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
  LoadGatherTestModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_GATHER) {
        uint32_t input_tensor_index = op->inputs()->Get(0);
        const auto weights_tensor =
            quantized_graph->tensors()->Get(input_tensor_index);
        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) {
  LoadCustomOpTestModel();

  // The custom op is not hybrid, and the second input is a constant that can
  // be quantized.
  CustomOpMap custom_op_map;
  custom_op_map["CustomTestOp"] = {
      .quantizable_input_indices = {1},
      .is_hybrid = false,
  };

  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                QuantizerType::OLD_QUANTIZER);
  ASSERT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  const auto quantized_graph = output_model->subgraphs()->Get(0);
  // A dequantize op should be added.
  ASSERT_EQ(quantized_graph->operators()->size(),
            model_->subgraphs()->Get(0)->operators()->size() + 1);
  int num_custom_ops_found = 0;
  for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
    const auto op = quantized_graph->operators()->Get(i);
    const uint32_t op_code_idx = op->opcode_index();
    const auto op_code =
        GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
    if (op_code == BuiltinOperator_CUSTOM) {
      uint32_t weights_tensor_index = op->inputs()->Get(1);
      const auto weights_tensor =
          quantized_graph->tensors()->Get(weights_tensor_index);
      EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32);

      // Check that it comes from a dequantize operation.
      BuiltinOperator producer_op_code;
      ASSERT_TRUE(GetProducerOpCode(output_model, 0, weights_tensor_index,
                                    &producer_op_code));
      EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE);
      num_custom_ops_found++;
    }
  }
  EXPECT_EQ(num_custom_ops_found, 1);
}

TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) {
  LoadCustomOpTestModel();

  // The custom op is hybrid, and the second input is a constant that can
  // be quantized.
  CustomOpMap custom_op_map;
  custom_op_map["CustomTestOp"] = {
      .quantizable_input_indices = {1},
      .is_hybrid = true,
  };

  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                QuantizerType::OLD_QUANTIZER);
  ASSERT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  const auto quantized_graph = output_model->subgraphs()->Get(0);
  ASSERT_EQ(quantized_graph->operators()->size(),
            model_->subgraphs()->Get(0)->operators()->size());
  int num_custom_ops_found = 0;
  for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
    const auto op = quantized_graph->operators()->Get(i);
    const uint32_t op_code_idx = op->opcode_index();
    const auto op_code =
        GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
    if (op_code == BuiltinOperator_CUSTOM) {
      uint32_t weights_tensor_index = op->inputs()->Get(1);
      const auto weights_tensor =
          quantized_graph->tensors()->Get(weights_tensor_index);
      EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      num_custom_ops_found++;
    }
  }
  EXPECT_EQ(num_custom_ops_found, 1);
}

TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  const CustomOpMap custom_op_map;
  auto status = QuantizeWeights(
      &builder, model_, 0, custom_op_map, /*use_updated_hybrid_scheme=*/false,
      /*op_denylist=*/{}, QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  // Nothing should change.
  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    // Make sure the graph only has one Conv operation.
    ASSERT_EQ(quantized_graph->operators()->size(), 1);
    const auto op = quantized_graph->operators()->Get(0);
    const uint32_t op_code_idx = op->opcode_index();
    ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)),
              BuiltinOperator_CONV_2D);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = float_graph->tensors()->Get(i);
      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
                GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
      // If the tensor is a weight, it should have type INT8, otherwise it
      // should stay with type FLOAT32.
      // If the tensor is a bias, it should have type FLOAT32.
      if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
            << quant_tensor->name()->str();
        auto shape = GetAsVector(quant_tensor->shape());
        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  const CustomOpMap custom_op_map;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                /*use_updated_hybrid_scheme=*/true,
                                /*op_denylist=*/{BuiltinOperator_CONV_2D},
                                QuantizerType::OLD_QUANTIZER);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have an extra tensor from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 1);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // If the tensor is a weight, it should have type INT8.
      // If the tensor is a bias, it should have type FLOAT32.
      // If the tensor is an input or output it should have type FLOAT32.
      // The input to dequantize should be INT8, and all other tensors should be
      // FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
        // The dequantize input should still be quantized per-channel.
        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 5);
        EXPECT_EQ(quant_tensor->quantization()->quantized_dimension(), 0);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

}  // namespace
}  // namespace optimize
}  // namespace tflite

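// The test binary requires --test_model_file to point at one of the test
// models; only its directory is used, since each test loads its own model by
// name from that directory.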
int main(int argc, char** argv) {
  tensorflow::string model_file;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("test_model_file", &model_file,
                       "Path to test tflite model file."),
  };

  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    std::cerr << "Required flag: --test_model_file\n";
    std::abort();
  }
  g_test_model_dir =
      new tensorflow::string(tensorflow::io::Dirname(model_file));
  ::tensorflow::port::InitMain(argv[0], &argc, &argv);
  return RUN_ALL_TESTS();
}