/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "llvm/ADT/Twine.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/test_util.h"
// Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc

namespace {
tensorflow::string* g_test_model_dir = nullptr;
}  // namespace

namespace tflite {
namespace optimize {
namespace {

using mlir::lite::BufferType;
using mlir::lite::CustomOpMap;
using mlir::lite::QuantizeWeights;
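// Expected default of the updated hybrid quantization scheme. When true,
// hybrid op weights are expected to carry one scale per output channel;
// when false, a single per-tensor scale is expected (see the HybridConv and
// VerifyUpdatedHybridSchemeFalseQuantizationHybrid tests below).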
constexpr bool kUseUpdatedHybridSchemeDefault = true;

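// Unpacks a flatbuffer-backed Model into a mutable ModelT copy so that raw
// buffer contents can be inspected and compared.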
std::unique_ptr<ModelT> CreateMutableModelFromFile(const Model* input_model) {
  auto copied_model = std::make_unique<ModelT>();
  input_model->UnPackTo(copied_model.get(), nullptr);
  return copied_model;
}

std::unique_ptr<FlatBufferModel> ReadTestModel() {
  auto model_path = tensorflow::io::JoinPath(
      *g_test_model_dir, internal::kConvModelWith0Plus10Weights);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
                                             internal::kModelWithSharedWeights);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
                                             internal::kQuantizedWithGather);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
  auto model_path =
      tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

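// Copies a flatbuffers::Vector into a std::vector so it can be compared with
// standard operators.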
template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
  return std::vector<T>(vec->begin(), vec->end());
}

class QuantizeWeightsTest : public testing::Test {
 protected:
  QuantizeWeightsTest() {}

  void LoadBasicModel() {
    input_model_ = ReadTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadSharedWeightsModel() {
    input_model_ = ReadSharedWeightsTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadGatherTestModel() {
    input_model_ = ReadGatherTestModel();
    model_ = input_model_->GetModel();
  }

  void LoadCustomOpTestModel() {
    input_model_ = ReadCustomOpTestModel();
    model_ = input_model_->GetModel();
  }

  std::unique_ptr<FlatBufferModel> input_model_;
  const Model* model_;

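  // Returns true if tensor_idx is an input or output of any subgraph in the
  // model.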
  bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) {
    for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
         ++subgraph_idx) {
      const auto subgraph = model->subgraphs()->Get(subgraph_idx);
      for (size_t i = 0; i < subgraph->inputs()->size(); ++i) {
        if (subgraph->inputs()->Get(i) == tensor_idx) {
          return true;
        }
      }
      for (size_t i = 0; i < subgraph->outputs()->size(); ++i) {
        if (subgraph->outputs()->Get(i) == tensor_idx) {
          return true;
        }
      }
    }
    return false;
  }

  // Looks up the operator that produces tensor_idx in the given subgraph and
  // writes its builtin code to op_code. Returns false if no producer is found.
  bool GetProducerOpCode(const Model* model, uint32_t subgraph_idx,
                         uint32_t tensor_idx, BuiltinOperator* op_code) {
    const auto subgraph = model->subgraphs()->Get(subgraph_idx);
    for (size_t op_idx = 0; op_idx < subgraph->operators()->size(); ++op_idx) {
      const auto op = subgraph->operators()->Get(op_idx);
      for (size_t i = 0; i < op->outputs()->size(); ++i) {
        if (op->outputs()->Get(i) == tensor_idx) {
          const uint32_t op_code_idx = op->opcode_index();
          *op_code = GetBuiltinCode(model->operator_codes()->Get(op_code_idx));
          return true;
        }
      }
    }
    return false;
  }
};

bool ExpectEqualTensor(const Tensor* tensor, const Tensor* expected_tensor) {
  // Everything should remain equal between the two graphs.
  return (tensor->is_variable() == expected_tensor->is_variable()) &&
         (GetAsVector(tensor->shape()) ==
          GetAsVector(expected_tensor->shape())) &&
         (tensor->name()->str() == expected_tensor->name()->str());
}

// Finds the match for the quantized tensor among the possible tensors. Each
// possible tensor can be used only once. For quantized tensors it checks shape
// and name only; for non-quantized tensors it also checks buffer contents and
// tensor type. In the quantized case, tensor type and quantization params are
// expected to be checked in the test body against the returned match.
const Tensor* FindMatchingExpectedTensor(
    const Model* quantized_model, const Model* expected_model,
    const Tensor* quantized_tensor,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* possible_tensors,
    std::vector<int>& used_tensors, bool quantized = false) {
  std::unique_ptr<ModelT> quant_model =
      CreateMutableModelFromFile(quantized_model);
  std::unique_ptr<ModelT> float_model =
      CreateMutableModelFromFile(expected_model);

  for (int i = 0; i < possible_tensors->size(); i++) {
    // Skip if the tensor has already been used for a match.
    auto it = std::find(used_tensors.begin(), used_tensors.end(), i);
    if (it != used_tensors.end()) continue;

    const Tensor* float_tensor = possible_tensors->Get(i);

    if (ExpectEqualTensor(quantized_tensor, float_tensor)) {
      if (quantized && quantized_tensor->name()->str().find("weights")) {
        // If the tensor is quantized, the data type and buffer contents can
        // differ between the float and quantized tensors, so those checks are
        // done separately in the test body rather than here.
        used_tensors.push_back(i);
        return float_tensor;
      } else {
        // Otherwise, do additional checks for data type and buffer contents.
        const std::vector<uint8_t> quantized_buffer =
            quant_model->buffers[quantized_tensor->buffer()].get()->data;
        const std::vector<uint8_t> float_buffer =
            float_model->buffers[float_tensor->buffer()].get()->data;
        if ((quantized_buffer == float_buffer) &&
            (quantized_tensor->type() == float_tensor->type())) {
          used_tensors.push_back(i);
          return float_tensor;
        }
      }
    }
  }
  return nullptr;
}

TEST_F(QuantizeWeightsTest, QuantizationSucceeds) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_F(QuantizeWeightsTest, QuantizationFails) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  tflite::StderrReporter error_reporter;
  auto status = QuantizeWeights(&builder, model_, &error_reporter,
                                TensorType_UINT8, {}, {});
  EXPECT_EQ(status, kTfLiteError);
}

TEST_F(QuantizeWeightsTest, WeightsMinNumElements) {
  LoadBasicModel();
  // Make the minimum number of weight elements sufficiently large so that no
  // quantization happens, i.e. the output model is identical to the input.
  flatbuffers::FlatBufferBuilder builder;
  const uint64_t kWeightsMinNumElements = 1000000;
  EXPECT_EQ(QuantizeWeights(&builder, model_, kWeightsMinNumElements),
            kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    std::vector<int> used_tensors;
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = FindMatchingExpectedTensor(
          /*quantized_model=*/output_model, /*expected_model=*/model_,
          /*quantized_tensor=*/quant_tensor,
          /*possible_tensors=*/float_graph->tensors(),
          /*used_tensors=*/used_tensors);
      EXPECT_NE(float_tensor, nullptr);
    }
    EXPECT_EQ(used_tensors.size(), quantized_graph->tensors()->size());
  }
}

TEST_F(QuantizeWeightsTest, HybridConv) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  // Nothing should change.
  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    // Make sure the graph only has one Conv operation.
    ASSERT_EQ(quantized_graph->operators()->size(), 1);
    const auto op = quantized_graph->operators()->Get(0);
    const uint32_t op_code_idx = op->opcode_index();
    ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)),
              BuiltinOperator_CONV_2D);
    std::vector<int> used_tensors;
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = FindMatchingExpectedTensor(
          /*quantized_model=*/output_model, /*expected_model=*/model_,
          /*quantized_tensor=*/quant_tensor,
          /*possible_tensors=*/float_graph->tensors(),
          /*used_tensors=*/used_tensors, /*quantized=*/true);
      EXPECT_NE(float_tensor, nullptr);
      // If the tensor is a weight, it should have type INT8, otherwise it
      // should stay with type FLOAT32.
      // If the tensor is a bias, it should have type FLOAT32.
      if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
            << quant_tensor->name()->str();
        auto shape = GetAsVector(quant_tensor->shape());
        if (kUseUpdatedHybridSchemeDefault) {
          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
        } else {
          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
        }
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
    EXPECT_EQ(used_tensors.size(), quantized_graph->tensors()->size());
  }
}

TEST_F(QuantizeWeightsTest, DequantizeConv) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0,
                                /*use_hybrid_evaluation=*/false);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have an extra tensor from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 1);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // If the tensor is a weight, it should have type INT8.
      // If the tensor is a bias, it should have type FLOAT32.
      // If the tensor is an input or output it should have type FLOAT32.
      // The input to dequantize should be INT8, and all other tensors should be
      // FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->buffer() != 0) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status =
      QuantizeWeights(&builder, model_, BufferType::QUANTIZED_FLOAT16);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have two extra tensors from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 2);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // If the tensor is a weight, it should have type FLOAT16.
      // If the tensor is a bias, it should have type FLOAT16.
      // If the tensor is an input or output it should have type FLOAT32.
      // The input to dequantize should be FLOAT16, and all other tensors should
      // be FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else if ((!CreateMutableModelFromFile(output_model)
                       ->buffers[quant_tensor->buffer()]
                       .get()
                       ->data.empty())) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) {
  LoadSharedWeightsModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  uint32_t num_conv_ops = 0;
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_CONV_2D) {
        num_conv_ops++;
        // Ensure that each convolution's weights tensor is now INT8.
        const auto weights_tensor =
            quantized_graph->tensors()->Get(op->inputs()->Get(1));
        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      }
    }
  }
  // Ensure that there were exactly two convolutions in the model.
  EXPECT_EQ(num_conv_ops, 2);
}

TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
  LoadSharedWeightsModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0,
                                /*use_hybrid_evaluation=*/false);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  uint32_t num_conv_ops = 0;
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_CONV_2D) {
        num_conv_ops++;
        // Ensure that each convolution's weights tensor is still FLOAT
        // (the output of the dequantize).
        uint32_t weights_tensor_index = op->inputs()->Get(1);
        const auto weights_tensor =
            quantized_graph->tensors()->Get(weights_tensor_index);
        EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32);

        // Check that it comes from a dequantize operation.
        BuiltinOperator producer_op_code;
        ASSERT_TRUE(GetProducerOpCode(output_model, subgraph_idx,
                                      weights_tensor_index, &producer_op_code));
        EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE);
      }
    }
  }
  // Ensure that there were exactly two convolutions in the model.
  EXPECT_EQ(num_conv_ops, 2);
}

TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
  LoadGatherTestModel();
  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      const auto op_code =
          GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
      if (op_code == BuiltinOperator_GATHER) {
        uint32_t input_tensor_index = op->inputs()->Get(0);
        const auto weights_tensor =
            quantized_graph->tensors()->Get(input_tensor_index);
        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      }
    }
  }
}

TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) {
  LoadCustomOpTestModel();

  // The custom op is not hybrid, and the second input is a constant that can
  // be quantized.
  CustomOpMap custom_op_map;
  custom_op_map["CustomTestOp"] = {
      {1},   // quantizable_input_indices
      true,  // is_weight_only
  };

  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map);
  ASSERT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  const auto quantized_graph = output_model->subgraphs()->Get(0);
  // A dequantize op should be added.
  ASSERT_EQ(quantized_graph->operators()->size(),
            model_->subgraphs()->Get(0)->operators()->size() + 1);
  int num_custom_ops_found = 0;
  for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
    const auto op = quantized_graph->operators()->Get(i);
    const uint32_t op_code_idx = op->opcode_index();
    const auto op_code =
        GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
    if (op_code == BuiltinOperator_CUSTOM) {
      uint32_t weights_tensor_index = op->inputs()->Get(1);
      const auto weights_tensor =
          quantized_graph->tensors()->Get(weights_tensor_index);
      EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32);

      // Check that it comes from a dequantize operation.
      BuiltinOperator producer_op_code;
      ASSERT_TRUE(GetProducerOpCode(output_model, 0, weights_tensor_index,
                                    &producer_op_code));
      EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE);
      num_custom_ops_found++;
    }
  }
  EXPECT_EQ(num_custom_ops_found, 1);
}

TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) {
  LoadCustomOpTestModel();

  // The custom op is dynamic range quantizable, and the second input is a
  // constant that can be quantized.
  CustomOpMap custom_op_map;
  custom_op_map["CustomTestOp"] = {
      {1},    // quantizable_input_indices
      false,  // is_weight_only
  };

  flatbuffers::FlatBufferBuilder builder;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map);
  ASSERT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  const auto quantized_graph = output_model->subgraphs()->Get(0);
  ASSERT_EQ(quantized_graph->operators()->size(),
            model_->subgraphs()->Get(0)->operators()->size());
  int num_custom_ops_found = 0;
  for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
    const auto op = quantized_graph->operators()->Get(i);
    const uint32_t op_code_idx = op->opcode_index();
    const auto op_code =
        GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx));
    if (op_code == BuiltinOperator_CUSTOM) {
      uint32_t weights_tensor_index = op->inputs()->Get(1);
      const auto weights_tensor =
          quantized_graph->tensors()->Get(weights_tensor_index);
      EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
      num_custom_ops_found++;
    }
  }
  EXPECT_EQ(num_custom_ops_found, 1);
}

TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  const CustomOpMap custom_op_map;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, false);
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  // Nothing should change.
  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       subgraph_idx++) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size());
    // Make sure the graph only has one Conv operation.
    ASSERT_EQ(quantized_graph->operators()->size(), 1);
    const auto op = quantized_graph->operators()->Get(0);
    const uint32_t op_code_idx = op->opcode_index();
    ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)),
              BuiltinOperator_CONV_2D);
    std::vector<int> used_tensors;
    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      const auto float_tensor = FindMatchingExpectedTensor(
          /*quantized_model=*/output_model, /*expected_model=*/model_,
          /*quantized_tensor=*/quant_tensor,
          /*possible_tensors=*/float_graph->tensors(),
          /*used_tensors=*/used_tensors, /*quantized=*/true);
      EXPECT_NE(float_tensor, nullptr);
      // If the tensor is a weight, it should have type INT8, otherwise it
      // should stay with type FLOAT32.
      // If the tensor is a bias, it should have type FLOAT32.
      if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if ((!CreateMutableModelFromFile(output_model)
                       ->buffers[quant_tensor->buffer()]
                       .get()
                       ->data.empty())) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
            << quant_tensor->name()->str();
        auto shape = GetAsVector(quant_tensor->shape());
        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
    EXPECT_EQ(used_tensors.size(), quantized_graph->tensors()->size());
  }
}

TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) {
  LoadBasicModel();
  flatbuffers::FlatBufferBuilder builder;
  const CustomOpMap custom_op_map;
  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map,
                                /*use_updated_hybrid_scheme=*/true,
                                {BuiltinOperator_CONV_2D});
  EXPECT_EQ(status, kTfLiteOk);

  const uint8_t* buffer = builder.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);

  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
       ++subgraph_idx) {
    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
    // The output graph should have an extra tensor from the added dequantize
    // op.
    ASSERT_EQ(quantized_graph->tensors()->size(),
              float_graph->tensors()->size() + 1);
    // Check that a dequantize op exists.
    int32_t dequant_input_idx = -1;
    int32_t dequant_output_idx = -1;
    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
      const auto op = quantized_graph->operators()->Get(i);
      const uint32_t op_code_idx = op->opcode_index();
      if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) ==
          BuiltinOperator_DEQUANTIZE) {
        dequant_input_idx = op->inputs()->Get(0);
        dequant_output_idx = op->outputs()->Get(0);
      }
    }
    ASSERT_GT(dequant_input_idx, -1);
    ASSERT_GT(dequant_output_idx, -1);
    for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
      const auto quant_tensor = quantized_graph->tensors()->Get(i);
      // If the tensor is a weight, it should have type INT8.
      // If the tensor is a bias, it should have type FLOAT32.
      // If the tensor is an input or output it should have type FLOAT32.
      // The input to dequantize should be INT8, and all other tensors should be
      // FLOAT32.
      if (i == dequant_input_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
        // The dequantize should still be quantized per-channel
        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 5);
        EXPECT_EQ(quant_tensor->quantization()->quantized_dimension(), 0);
      } else if (i == dequant_output_idx) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (IsModelInputOrOutput(output_model, i)) {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if (quant_tensor->name()->str() == "conv_bias") {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      } else if ((!CreateMutableModelFromFile(output_model)
                       ->buffers[quant_tensor->buffer()]
                       .get()
                       ->data.empty())) {
        // If it's a non-bias constant tensor, it must be the weight.
        EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
      } else {
        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
      }
    }
  }
}

}  // namespace
}  // namespace optimize
}  // namespace tflite

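// Test driver: parses --test_model_file, stores the directory containing the
// model so the Read*TestModel helpers above can find it, then runs all tests.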
int main(int argc, char** argv) {
  tensorflow::string model_file;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("test_model_file", &model_file,
                       "Path to test tflite model file."),
  };

  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    std::cerr << "Required test_model_file\n";
    std::abort();
  }
  g_test_model_dir =
      new tensorflow::string(tensorflow::io::Dirname(model_file));
  ::tensorflow::port::InitMain(argv[0], &argc, &argv);
  return RUN_ALL_TESTS();
}