1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/quantize_model.h"
16 
17 #include <cstddef>
18 #include <cstdint>
19 #include <memory>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
24 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
25 #include "tensorflow/core/lib/io/path.h"
26 #include "tensorflow/core/platform/init_main.h"
27 #include "tensorflow/core/util/command_line_flags.h"
28 #include "tensorflow/lite/model.h"
29 #include "tensorflow/lite/schema/schema_generated.h"
30 #include "tensorflow/lite/schema/schema_utils.h"
31 #include "tensorflow/lite/tools/optimize/test_util.h"
32 
33 // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc
34 
35 namespace {
36 tensorflow::string* g_test_model_dir = nullptr;
37 }  // namespace
38 
39 namespace tflite {
40 namespace optimize {
41 namespace {
42 
43 std::unique_ptr<FlatBufferModel> ReadModel(const string& model_name) {
44   auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name);
45   return FlatBufferModel::BuildFromFile(model_path.c_str());
46 }
47 
48 template <typename T>
49 std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
50   return std::vector<T>(vec->begin(), vec->end());
51 }
52 
53 void VerifyAsymmetricQuantizationScale(
54     const QuantizationParameters& float_quant_params,
55     const QuantizationParametersT& quantized_quant_params) {
56   const float eps = 1e-7;
57   ASSERT_EQ(float_quant_params.min()->size(), 1);
58   ASSERT_EQ(float_quant_params.max()->size(), 1);
59   float float_min = std::min(0.f, float_quant_params.min()->Get(0));
60   float float_max = std::max(0.f, float_quant_params.max()->Get(0));
61 
62   ASSERT_EQ(quantized_quant_params.scale.size(), 1);
63   ASSERT_EQ(quantized_quant_params.zero_point.size(), 1);
64 
65   float scale = (float_max - float_min) / 255;
66   EXPECT_NEAR(scale, quantized_quant_params.scale[0], eps);
67 }
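
// Illustrative sketch (not part of the original test): one way to derive an
// asymmetric int8 scale/zero-point pair from a float [min, max] range,
// matching the scale formula verified above. The helper name is hypothetical,
// the zero-point nudging is simplified, and <algorithm>/<cmath> are assumed
// to be available; a real implementation would also guard against scale == 0.
struct AsymmetricQuantParamsForTest {
  float scale;
  int32_t zero_point;
};

AsymmetricQuantParamsForTest ComputeAsymmetricQuantParamsForTest(float rmin,
                                                                 float rmax) {
  // Extend the range to include zero, as VerifyAsymmetricQuantizationScale
  // does, so that the real value 0.0 is exactly representable.
  rmin = std::min(0.f, rmin);
  rmax = std::max(0.f, rmax);
  const float scale = (rmax - rmin) / 255.f;
  // Map the real value 0.0 onto the int8 grid [-128, 127].
  const int32_t zero_point =
      static_cast<int32_t>(std::round(-128.f - rmin / scale));
  return {scale, zero_point};
}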
68 
69 class QuantizeModelTest : public testing::Test {
70  protected:
71   QuantizeModelTest() {
72     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
73     readonly_model_ = input_model_->GetModel();
74     readonly_model_->UnPackTo(&model_);
75   }
76 
77   std::unique_ptr<FlatBufferModel> input_model_;
78   const Model* readonly_model_;
79   tflite::ModelT model_;
80   flatbuffers::FlatBufferBuilder builder_;
81   internal::FailOnErrorReporter error_reporter_;
82 };
83 
84 void ExpectSameModels(const ModelT& model, const ModelT& expected_model) {
85   ASSERT_EQ(model.subgraphs.size(), expected_model.subgraphs.size());
86   for (size_t subgraph_idx = 0; subgraph_idx < model.subgraphs.size();
87        subgraph_idx++) {
88     const auto graph = model.subgraphs[subgraph_idx].get();
89     const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
90     ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size());
91     for (size_t i = 0; i < graph->tensors.size(); i++) {
92       const auto tensor = graph->tensors[i].get();
93       const auto expected_tensor = expected_graph->tensors[i].get();
94       EXPECT_EQ(tensor->buffer, expected_tensor->buffer);
95       EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
96       EXPECT_EQ(tensor->shape, expected_tensor->shape);
97       EXPECT_EQ(tensor->name, expected_tensor->name);
98       EXPECT_EQ(tensor->type, expected_tensor->type);
99       const auto quantization_params = tensor->quantization.get();
100       const auto expected_quantization_params =
101           expected_tensor->quantization.get();
102       if (quantization_params != nullptr ||
103           expected_quantization_params != nullptr) {
104         EXPECT_NE(quantization_params, nullptr);
105         EXPECT_NE(expected_quantization_params, nullptr);
106         EXPECT_EQ(quantization_params->scale,
107                   expected_quantization_params->scale);
108         EXPECT_EQ(quantization_params->zero_point,
109                   expected_quantization_params->zero_point);
110       }
111     }
112   }
113   ASSERT_EQ(model.buffers.size(), expected_model.buffers.size());
114   for (size_t buffer_idx = 0; buffer_idx < model.buffers.size(); ++buffer_idx) {
115     const auto buffer = model.buffers[buffer_idx].get()->data;
116     const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data;
117     EXPECT_EQ(buffer, expected_buffer);
118   }
119   // TODO(jianlijianli): Compare operators as well.
120 }
121 
122 class QuantizeConvModelTest : public QuantizeModelTest,
123                               public testing::WithParamInterface<TensorType> {
124  protected:
125   QuantizeConvModelTest() {
126     tensor_type_ = GetParam();
127     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
128     readonly_model_ = input_model_->GetModel();
129     readonly_model_->UnPackTo(&model_);
130   }
131   TensorType tensor_type_;
132 };
133 
134 INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
135                          testing::ValuesIn({TensorType_INT8,
136                                             TensorType_INT16}));
137 
138 TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
139   auto status =
140       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
141                                 false, tensor_type_, &error_reporter_);
142   EXPECT_EQ(status, kTfLiteOk);
143   const uint8_t* buffer = builder_.GetBufferPointer();
144   const Model* output_model = GetModel(buffer);
145   ASSERT_TRUE(output_model);
146 }
147 
148 TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
149   auto status = QuantizeModel(
150       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
151       /*allow_float=*/true, {}, TensorType_FLOAT32, &error_reporter_);
152   EXPECT_EQ(status, kTfLiteOk);
153   ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
154   // The resulting model should be the same.
155   ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
156   for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
157        subgraph_idx++) {
158     const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
159     const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
160     ASSERT_EQ(quantized_graph->tensors.size(), float_graph->tensors()->size());
161     for (size_t i = 0; i < quantized_graph->tensors.size(); i++) {
162       const auto quant_tensor = quantized_graph->tensors[i].get();
163       const auto float_tensor = float_graph->tensors()->Get(i);
164       EXPECT_EQ(quant_tensor->buffer, float_tensor->buffer());
165       EXPECT_EQ(quant_tensor->is_variable, float_tensor->is_variable());
166       EXPECT_EQ(quant_tensor->shape, GetAsVector(float_tensor->shape()));
167       EXPECT_EQ(quant_tensor->name, float_tensor->name()->str());
168       EXPECT_EQ(quant_tensor->type, float_tensor->type());
169     }
170   }
171 }
172 
173 TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
174   auto status =
175       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
176                                 false, tensor_type_, &error_reporter_);
177   EXPECT_EQ(status, kTfLiteOk);
178   ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
179   for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
180        subgraph_idx++) {
181     const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
182     const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
183     ASSERT_EQ(quantized_graph->tensors.size(), float_graph->tensors()->size());
184     for (size_t i = 0; i < quantized_graph->tensors.size(); i++) {
185       const auto quant_tensor = quantized_graph->tensors[i].get();
186       const auto float_tensor = float_graph->tensors()->Get(i);
187       EXPECT_EQ(quant_tensor->buffer, float_tensor->buffer());
188       EXPECT_EQ(quant_tensor->is_variable, float_tensor->is_variable());
189       EXPECT_EQ(quant_tensor->shape, GetAsVector(float_tensor->shape()));
190       EXPECT_EQ(quant_tensor->name, float_tensor->name()->str());
191     }
192   }
193   // check op and versioning.
194   EXPECT_EQ(model_.operator_codes.size(), 1);
195   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
196             BuiltinOperator_CONV_2D);
197   EXPECT_EQ(model_.operator_codes[0]->version, 3);
198 }
199 
200 TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
201   auto status =
202       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
203                                 false, tensor_type_, &error_reporter_);
204   EXPECT_EQ(status, kTfLiteOk);
205   ASSERT_EQ(model_.operator_codes.size(),
206             readonly_model_->operator_codes()->size());
207   for (size_t i = 0; i < model_.operator_codes.size(); i++) {
208     const auto float_model_op = readonly_model_->operator_codes()->Get(i);
209     EXPECT_EQ(GetBuiltinCode(model_.operator_codes[i].get()),
210               GetBuiltinCode(float_model_op));
211     if (GetBuiltinCode(model_.operator_codes[i].get()) ==
212         BuiltinOperator_CONV_2D) {
213       EXPECT_EQ(model_.operator_codes[i]->version, 3);
214     } else {
215       EXPECT_EQ(model_.operator_codes[i]->version, 2);
216     }
217   }
218 
219   ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
220   for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
221        subgraph_idx++) {
222     const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
223     const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
224     ASSERT_EQ(quantized_graph->operators.size(),
225               float_graph->operators()->size());
226     for (size_t i = 0; i < quantized_graph->operators.size(); i++) {
227       const auto quant_op = quantized_graph->operators[i].get();
228       const auto float_op = float_graph->operators()->Get(i);
229       EXPECT_EQ(quant_op->inputs, GetAsVector(float_op->inputs()));
230       EXPECT_EQ(quant_op->outputs, GetAsVector(float_op->outputs()));
231       EXPECT_EQ(quant_op->opcode_index, float_op->opcode_index());
232     }
233   }
234 }
235 
236 TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
237   auto status = QuantizeModelAllOperators(
238       &builder_, &model_, tensor_type_, tensor_type_,
239       /*allow_float*/ false, tensor_type_, &error_reporter_);
240   EXPECT_EQ(status, kTfLiteOk);
241   for (const auto& subgraph : model_.subgraphs) {
242     for (const auto& tensor : subgraph->tensors) {
243       if (tensor_type_ == TensorType_INT8) {
244         EXPECT_TRUE(tensor->type == TensorType_INT32 ||
245                     tensor->type == TensorType_INT8);
246       } else if (tensor_type_ == TensorType_INT16) {
247         EXPECT_TRUE(tensor->type == TensorType_INT64 ||  // bias
248                     tensor->type == TensorType_INT8 ||   // weights
249                     tensor->type == TensorType_INT16);   // activations
250       }
251     }
252   }
253 }
254 
255 TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
256   auto status = QuantizeModelAllOperators(
257       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
258       /*allow_float*/ false, tensor_type_, &error_reporter_);
259   EXPECT_EQ(status, kTfLiteOk);
260 
261   for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
262        ++subgraph_idx) {
263     const auto& subgraph = model_.subgraphs[subgraph_idx];
264     const auto& readonly_subgraph =
265         readonly_model_->subgraphs()->Get(subgraph_idx);
266     // The model has one input and output, so the converted model should have
267     // two extra ops, a Quantize and Dequantize.
268     EXPECT_EQ(subgraph->operators.size(),
269               readonly_subgraph->operators()->size() + 2);
270     // Check that the first op is Quantize and the last is Dequant.
271     const auto& quant_op = subgraph->operators[0];
272     const auto& dequant_op =
273         subgraph->operators[subgraph->operators.size() - 1];
274     const int32_t quant_idx = quant_op->opcode_index;
275     const int32_t dequant_idx = dequant_op->opcode_index;
276     EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()),
277               BuiltinOperator_QUANTIZE);
278     EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()),
279               BuiltinOperator_DEQUANTIZE);
280     // The model should only have one input and output.
281     EXPECT_EQ(subgraph->inputs.size(), 1);
282     EXPECT_EQ(subgraph->outputs.size(), 1);
283     const int32_t input_idx = subgraph->inputs[0];
284     const int32_t output_idx = subgraph->outputs[0];
285     // Ensure: new input -> Quant -> old input.
286     EXPECT_EQ(quant_op->inputs[0], input_idx);
287     EXPECT_EQ(quant_op->outputs[0], readonly_subgraph->inputs()->Get(0));
288     // Ensure: old output -> dequant -> new output.
289     EXPECT_EQ(dequant_op->inputs[0], readonly_subgraph->outputs()->Get(0));
290     EXPECT_EQ(dequant_op->outputs[0], output_idx);
291     // The input and output types should be float.
292     EXPECT_EQ(subgraph->tensors[input_idx]->type, TensorType_FLOAT32);
293     EXPECT_EQ(subgraph->tensors[input_idx]->name, "input");
294     EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
295     EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
296     // The original input and output have been renamed.
297     std::string control_suffix =
298         (tensor_type_ == TensorType_INT16) ? "int16" : "int8";
299     EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
300               "input_" + control_suffix);
301     EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
302               "output_" + control_suffix);
303     for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
304          ++tensor_idx) {
305       const auto& tensor = subgraph->tensors[tensor_idx];
306       if (input_idx != tensor_idx && output_idx != tensor_idx) {
307         if (tensor_type_ == TensorType_INT8) {
308           EXPECT_TRUE(tensor->type == TensorType_INT32 ||
309                       tensor->type == TensorType_INT8);
310         } else if (tensor_type_ == TensorType_INT16) {
311           EXPECT_TRUE(tensor->type == TensorType_INT64 ||  // bias
312                       tensor->type == TensorType_INT8 ||   // weights
313                       tensor->type == TensorType_INT16);   // activations
314         }
315       }
316     }
317   }
318 }
319 
320 TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
321   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_UINT8,
322                                           TensorType_UINT8, false,
323                                           TensorType_INT8, &error_reporter_);
324   EXPECT_EQ(status, kTfLiteOk);
325 
326   for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
327        ++subgraph_idx) {
328     const auto& subgraph = model_.subgraphs[subgraph_idx];
329     const auto& readonly_subgraph =
330         readonly_model_->subgraphs()->Get(subgraph_idx);
331     // The model has one input and output, so the converted model should have
332     // two extra ops; here both are Quantize (uint8 -> int8 and int8 -> uint8).
333     EXPECT_EQ(subgraph->operators.size(),
334               readonly_subgraph->operators()->size() + 2);
335     // Check that both the first and the last op are Quantize ops.
336     const auto& quant_op_uint8_int8 = subgraph->operators[0];
337     const auto& quant_op_int8_uint8 =
338         subgraph->operators[subgraph->operators.size() - 1];
339     const int32_t quant_op_uint8_int8_idx = quant_op_uint8_int8->opcode_index;
340     const int32_t quant_op_int8_uint8_idx = quant_op_int8_uint8->opcode_index;
341     EXPECT_EQ(
342         GetBuiltinCode(model_.operator_codes[quant_op_uint8_int8_idx].get()),
343         BuiltinOperator_QUANTIZE);
344     EXPECT_EQ(
345         GetBuiltinCode(model_.operator_codes[quant_op_int8_uint8_idx].get()),
346         BuiltinOperator_QUANTIZE);
347     // The model should only have one input and output.
348     EXPECT_EQ(subgraph->inputs.size(), 1);
349     EXPECT_EQ(subgraph->outputs.size(), 1);
350     const int32_t input_idx = subgraph->inputs[0];
351     const int32_t output_idx = subgraph->outputs[0];
352     // Ensure: new input -> Quant -> old input.
353     EXPECT_EQ(quant_op_uint8_int8->inputs[0], input_idx);
354     EXPECT_EQ(quant_op_uint8_int8->outputs[0],
355               readonly_subgraph->inputs()->Get(0));
356     // Ensure: old output -> dequant -> new output.
357     EXPECT_EQ(quant_op_int8_uint8->inputs[0],
358               readonly_subgraph->outputs()->Get(0));
359     EXPECT_EQ(quant_op_int8_uint8->outputs[0], output_idx);
360     // The input and output types should be uint8.
361     EXPECT_EQ(subgraph->tensors[input_idx]->type, TensorType_UINT8);
362     EXPECT_EQ(subgraph->tensors[input_idx]->name, "input");
363     EXPECT_EQ(subgraph->tensors[input_idx]->quantization->scale.size(), 1);
364     EXPECT_FLOAT_EQ(subgraph->tensors[input_idx]->quantization->scale[0],
365                     0.0392156877);
366     EXPECT_EQ(subgraph->tensors[input_idx]->quantization->zero_point.size(), 1);
367     EXPECT_EQ(subgraph->tensors[input_idx]->quantization->zero_point[0], 0);
368     EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_UINT8);
369     EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
370     EXPECT_EQ(subgraph->tensors[output_idx]->quantization->scale.size(), 1);
371     EXPECT_FLOAT_EQ(subgraph->tensors[output_idx]->quantization->scale[0],
372                     0.0392156877);
373     EXPECT_EQ(subgraph->tensors[output_idx]->quantization->zero_point.size(),
374               1);
375     EXPECT_EQ(subgraph->tensors[output_idx]->quantization->zero_point[0], 0);
376     // The original input and output have been renamed.
377     EXPECT_EQ(subgraph->tensors[quant_op_uint8_int8->outputs[0]]->name,
378               "input_int8");
379     EXPECT_EQ(subgraph->tensors[quant_op_int8_uint8->inputs[0]]->name,
380               "output_int8");
381     for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
382          ++tensor_idx) {
383       const auto& tensor = subgraph->tensors[tensor_idx];
384       if (input_idx != tensor_idx && output_idx != tensor_idx) {
385         EXPECT_TRUE(tensor->type == TensorType_INT32 ||
386                     tensor->type == TensorType_INT8);
387       }
388     }
389   }
390 }
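
// Worked check of the expectations above (illustrative): a scale of
// 0.0392156877 corresponds to (10 - 0) / 255, i.e. an asymmetric uint8
// quantization of a [0, 10] range with zero point 0. The [0, 10] range is
// inferred here from the expected scale, not stated by the test itself.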
391 
392 class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
393  protected:
394   QuantizeConvNoBiasModelTest() {
395     input_model_ = ReadModel(internal::kConvModelWithNoBias);
396     readonly_model_ = input_model_->GetModel();
397     readonly_model_->UnPackTo(&model_);
398   }
399 };
400 
401 TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
402   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
403                                           TensorType_INT8, false,
404                                           TensorType_INT8, &error_reporter_);
405   EXPECT_EQ(status, kTfLiteOk);
406   const uint8_t* buffer = builder_.GetBufferPointer();
407   const Model* output_model = GetModel(buffer);
408   ASSERT_TRUE(output_model);
409 }
410 
411 class QuantizeConcatModelTest : public QuantizeModelTest,
412                                 public testing::WithParamInterface<TensorType> {
413  protected:
414   QuantizeConcatModelTest() {
415     input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
416     readonly_model_ = input_model_->GetModel();
417     readonly_model_->UnPackTo(&model_);
418   }
419 
420   void SetUp() override { tensor_type_ = GetParam(); }
421 
422   TensorType tensor_type_;
423 };
424 
425 // There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
426 // as min/max and "input1" has [0, 10] as min/max. The output "output" for
427 // concat has [0, 10] as min/max.
428 // After applying QuantizeModel(), "input0" will have a requant op added, along
429 // with a tensor "input0_requant" that has [0, 10] as min/max. So the topology
430 // becomes:
431 // input0 -> requant -> input0_requant \
432 //                                       concat - output
433 //                              input1 /
434 TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
435   auto status =
436       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
437                                 false, tensor_type_, &error_reporter_);
438   EXPECT_EQ(status, kTfLiteOk);
439 
440   // There is only one subgraph.
441   const int32_t subgraph_idx = 0;
442   const auto& subgraph = model_.subgraphs[subgraph_idx];
443   const auto& readonly_subgraph =
444       readonly_model_->subgraphs()->Get(subgraph_idx);
445 
446   // There should be two ops: quant and concat.
447   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
448   EXPECT_EQ(subgraph->operators.size(), 2);
449   const auto& requant = subgraph->operators[0];
450   const auto& concat = subgraph->operators[1];
451   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[requant->opcode_index].get()),
452             BuiltinOperator_QUANTIZE);
453   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[concat->opcode_index].get()),
454             BuiltinOperator_CONCATENATION);
455 
456   auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
457   /*
458      input0_scale_control
459         INT8: (5-0) / (2^8 - 1)
460         INT16: (5-0) / (2^16 / 2 - 1)
461      input1_scale
462         INT8: (10-0) / (2^8 - 1)
463         INT16: (10-0) / (2^16 / 2 - 1)
464   */
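  // Worked values for the formulas above (illustrative, not asserted
  // directly): 5/255 ~ 0.0196078, 10/255 ~ 0.0392157, 5/32767 ~ 0.0001525925
  // and 10/32767 ~ 0.0003051851, which match the float constants below.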
465   auto input0_scale_control =
466       tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
467   auto input1_scale =
468       tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
469 
470   // There should be 4 tensors: input0, input1, input0_requantized, output.
471   EXPECT_EQ(subgraph->tensors.size(), 4);
472   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
473   EXPECT_EQ(subgraph->tensors[0]->name, "input0");
474   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
475   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
476   EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
477                   input0_scale_control);
478   EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
479                   zero_point_control);
480   EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
481   EXPECT_EQ(subgraph->tensors[1]->name, "input1");
482   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
483   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
484   EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
485   EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
486                   zero_point_control);
487   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
488   EXPECT_EQ(subgraph->tensors[2]->name, "output");
489   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
490   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
491   EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
492   EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
493                   zero_point_control);
494   EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
495   EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
496   EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
497   EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
498   EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
499   EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
500                   zero_point_control);
501 
502   // The connection should be what is described in the comment.
503   EXPECT_EQ(requant->inputs.size(), 1);
504   EXPECT_EQ(requant->outputs.size(), 1);
505   EXPECT_EQ(requant->inputs[0], 0);
506   EXPECT_EQ(requant->outputs[0], 3);
507   EXPECT_EQ(concat->inputs.size(), 2);
508   EXPECT_EQ(concat->outputs.size(), 1);
509   EXPECT_EQ(concat->inputs[0], 3);
510   EXPECT_EQ(concat->inputs[1], 1);
511   EXPECT_EQ(concat->outputs[0], 2);
512 
513   // check op and versioning.
514   EXPECT_EQ(model_.operator_codes.size(), 2);
515   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
516             BuiltinOperator_CONCATENATION);
517   EXPECT_EQ(model_.operator_codes[0]->version, 2);
518   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
519             BuiltinOperator_QUANTIZE);
520   EXPECT_EQ(model_.operator_codes[1]->version, 2);
521 }
522 INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
523                          testing::ValuesIn({TensorType_INT8,
524                                             TensorType_INT16}));
525 class QuantizeSplitModelTest : public QuantizeModelTest {
526  protected:
527   QuantizeSplitModelTest() {
528     input_model_ = ReadModel(internal::kModelSplit);
529     readonly_model_ = input_model_->GetModel();
530     readonly_model_->UnPackTo(&model_);
531   }
532 };
533 
534 // There are two outputs for split with different scales; the resulting model
535 // should have the scales hardcoded to the input scale value.
536 TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
537   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
538                                           TensorType_INT8, false,
539                                           TensorType_INT8, &error_reporter_);
540   EXPECT_EQ(status, kTfLiteOk);
541 
542   // There is only one subgraph.
543   const int32_t subgraph_idx = 0;
544   const auto& subgraph = model_.subgraphs[subgraph_idx];
545   const auto& readonly_subgraph =
546       readonly_model_->subgraphs()->Get(subgraph_idx);
547 
548   // There should be two ops: the split and add in the original model.
549   EXPECT_EQ(readonly_subgraph->operators()->size(), 2);
550   EXPECT_EQ(subgraph->operators.size(), 2);
551   const auto& split = subgraph->operators[0];
552   const auto& add = subgraph->operators[1];
553   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[split->opcode_index].get()),
554             BuiltinOperator_SPLIT);
555   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[add->opcode_index].get()),
556             BuiltinOperator_ADD);
557 
558   // There should be 5 tensors: input, output, split, split/split_dim, split:1.
559   EXPECT_EQ(subgraph->tensors.size(), 5);
560 
561   EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
562   EXPECT_EQ(subgraph->tensors[0]->name, "input");
563   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
564   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
565   EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 1.0);
566   EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
567   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
568   EXPECT_EQ(subgraph->tensors[1]->name, "output");
569   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
570   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
571   EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 1.0);
572   EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
573   EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
574   EXPECT_EQ(subgraph->tensors[2]->name, "split");
575   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
576   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
577   EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 1.0);
578   EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
579   EXPECT_EQ(subgraph->tensors[4]->type, TensorType_INT8);
580   EXPECT_EQ(subgraph->tensors[4]->name, "split:1");
581   EXPECT_EQ(subgraph->tensors[4]->quantization->scale.size(), 1);
582   EXPECT_EQ(subgraph->tensors[4]->quantization->zero_point.size(), 1);
583   EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->scale[0], 1.0);
584   EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->zero_point[0], -128);
585 
586   // check op and versioning.
587   EXPECT_EQ(model_.operator_codes.size(), 2);
588   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
589             BuiltinOperator_SPLIT);
590   EXPECT_EQ(model_.operator_codes[0]->version, 2);
591 }
592 
593 class QuantizeConvModel1Test : public QuantizeModelTest {
594  protected:
595   QuantizeConvModel1Test() {
596     input_model_ = ReadModel(internal::kConvModelWithMinus128Plus127Weights);
597     readonly_model_ = input_model_->GetModel();
598     readonly_model_->UnPackTo(&model_);
599   }
600 };
601 
602 TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
603   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
604                                           TensorType_INT8, false,
605                                           TensorType_INT8, &error_reporter_);
606   EXPECT_EQ(status, kTfLiteOk);
607   const auto& subgraph = model_.subgraphs[0];
608 
609   auto conv_op = subgraph->operators[0].get();
610   const int input_tensor_idx = 0;
611   const int weights_tensor_idx = 1;
612   const int bias_tensor_index = 2;
613   const int output_tensor_idx = 0;
614   const auto bias_tensor =
615       subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
616   const auto input_tensor =
617       subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
618   const auto weights_tensor =
619       subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
620   const auto output_tensor =
621       subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
622 
623   EXPECT_EQ(bias_tensor->type, TensorType_INT32);
624   EXPECT_EQ(input_tensor->type, TensorType_INT8);
625   EXPECT_EQ(weights_tensor->type, TensorType_INT8);
626 
627   ASSERT_TRUE(weights_tensor->quantization);
628   const int out_channel_size = weights_tensor->shape[0];
629   ASSERT_TRUE(bias_tensor->quantization);
630   ASSERT_TRUE(weights_tensor->quantization);
631   const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
632   const std::vector<float>& weights_scales =
633       weights_tensor->quantization->scale;
634 
635   const std::vector<int64_t>& weights_zero_points =
636       weights_tensor->quantization->zero_point;
637 
638   ASSERT_EQ(bias_scales.size(), out_channel_size);
639   ASSERT_EQ(weights_scales.size(), out_channel_size);
640   ASSERT_EQ(weights_zero_points.size(), out_channel_size);
641   ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
642   ASSERT_EQ(output_tensor->quantization->scale.size(), 1);
643 
644   for (size_t i = 0; i < out_channel_size; i++) {
645     EXPECT_EQ(weights_scales[i], 1);
646     EXPECT_EQ(bias_scales[i], 1);
647     EXPECT_EQ(weights_zero_points[i], 0);
648   }
649 
650   EXPECT_EQ(input_tensor->quantization->scale[0], 1);
651   EXPECT_EQ(output_tensor->quantization->scale[0], 1);
652 
653   const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
654   ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]);
655   const int32_t* bias_values =
656       reinterpret_cast<int32_t*>(bias_buffer->data.data());
657   const auto original_bias_buffer =
658       readonly_model_->buffers()->Get(bias_tensor->buffer);
659   const float* bias_float_buffer =
660       reinterpret_cast<const float*>(original_bias_buffer->data()->data());
661 
662   const float eps = 1e-7;
663   for (size_t i = 0; i < bias_tensor->shape[0]; i++) {
664     const float bias_scale =
665         input_tensor->quantization->scale[0] * weights_scales[i];
666     auto dequantized_value = bias_values[i] * bias_scale;
667     EXPECT_NEAR(dequantized_value, bias_float_buffer[i], eps);
668   }
669 
670   const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
671   const auto original_weights_buffer =
672       readonly_model_->buffers()->Get(weights_tensor->buffer);
673   const int8_t* weight_values =
674       reinterpret_cast<int8_t*>(weights_buffer->data.data());
675   const float* weights_float_buffer =
676       reinterpret_cast<const float*>(original_weights_buffer->data()->data());
677   ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
678             original_weights_buffer->data()->size());
679   int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
680   for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
681     for (size_t j = 0; j < num_values_in_channel; j++) {
682       size_t element_idx = channel_idx * num_values_in_channel + j;
683       auto dequantized_value =
684           weight_values[element_idx] * weights_scales[channel_idx];
685       EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx], eps);
686     }
687   }
688 
689   // check op and versioning.
690   EXPECT_EQ(model_.operator_codes.size(), 1);
691   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
692             BuiltinOperator_CONV_2D);
693   EXPECT_EQ(model_.operator_codes[0]->version, 3);
694 }
695 
696 class QuantizeConvModel2Test : public QuantizeModelTest,
697                                public testing::WithParamInterface<TensorType> {
698  protected:
699   QuantizeConvModel2Test() {
700     tensor_type_ = GetParam();
701     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
702     readonly_model_ = input_model_->GetModel();
703     readonly_model_->UnPackTo(&model_);
704   }
705 
706   TensorType tensor_type_;
707 };
708 INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
709                          testing::ValuesIn({TensorType_INT8,
710                                             TensorType_INT16}));
711 
712 TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
713   auto status =
714       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
715                                 false, tensor_type_, &error_reporter_);
716   ASSERT_EQ(kTfLiteOk, status);
717   const auto& subgraph = model_.subgraphs[0];
718   auto conv_op = subgraph->operators[0].get();
719   const int input_tensor_idx = 0;
720   const int weights_tensor_idx = 1;
721   const int bias_tensor_index = 2;
722   const int output_tensor_idx = 0;
723   const auto bias_tensor =
724       subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
725   const auto input_tensor =
726       subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
727   const auto weights_tensor =
728       subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
729   const auto output_tensor =
730       subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();
731 
732   EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
733                                    ? TensorType_INT32
734                                    : TensorType_INT64);
735   EXPECT_EQ(input_tensor->type, tensor_type_);
736   EXPECT_EQ(weights_tensor->type, TensorType_INT8);
737 
738   ASSERT_TRUE(weights_tensor->quantization);
739   const int out_channel_size = weights_tensor->shape[0];
740   ASSERT_TRUE(bias_tensor->quantization);
741   ASSERT_TRUE(weights_tensor->quantization);
742   const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
743   const std::vector<float>& weights_scales =
744       weights_tensor->quantization->scale;
745   const std::vector<int64_t>& weights_zero_points =
746       weights_tensor->quantization->zero_point;
747 
748   ASSERT_EQ(bias_scales.size(), out_channel_size);
749   ASSERT_EQ(weights_scales.size(), out_channel_size);
750   ASSERT_EQ(weights_zero_points.size(), out_channel_size);
751   ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
752   ASSERT_EQ(output_tensor->quantization->scale.size(), 1);
753 
754   const float eps = 1e-7;
755 
756   // Bias scale should be input * per_channel_weight_scale.
757   for (size_t i = 0; i < out_channel_size; i++) {
758     EXPECT_NEAR(bias_scales[i],
759                 input_tensor->quantization->scale[0] * weights_scales[i], eps);
760   }
761 
762   const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
763   auto control_size = tensor_type_ == TensorType_INT8
764                           ? sizeof(int32_t) * bias_tensor->shape[0]
765                           : sizeof(int64_t) * bias_tensor->shape[0];
766 
767   ASSERT_EQ(bias_buffer->data.size(), control_size);
768   const auto original_bias_buffer =
769       readonly_model_->buffers()->Get(bias_tensor->buffer);
770   const float* bias_float_buffer =
771       reinterpret_cast<const float*>(original_bias_buffer->data()->data());
772 
773   if (tensor_type_ == TensorType_INT8) {
774     int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
775     for (size_t i = 0; i < out_channel_size; i++) {
776       auto dequantized_value = bias_values[i] * bias_scales[i];
777       EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
778     }
779   } else if (tensor_type_ == TensorType_INT16) {
780     int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
781     for (size_t i = 0; i < out_channel_size; i++) {
782       auto dequantized_value = bias_values[i] * bias_scales[i];
783       EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
784     }
785   }
786 
787   const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
788   const auto original_weights_buffer =
789       readonly_model_->buffers()->Get(weights_tensor->buffer);
790   const int8_t* weight_values =
791       reinterpret_cast<int8_t*>(weights_buffer->data.data());
792   const float* weights_float_buffer =
793       reinterpret_cast<const float*>(original_weights_buffer->data()->data());
794   ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
795             original_weights_buffer->data()->size());
796   int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
797   for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
798     for (size_t j = 0; j < num_values_in_channel; j++) {
799       size_t element_idx = channel_idx * num_values_in_channel + j;
800       auto scale = weights_scales[channel_idx];
801       auto zero_point = weights_zero_points[channel_idx];
802       auto dequantized_value = weight_values[element_idx] * scale;
803       EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
804                   scale / 2);
805       EXPECT_EQ(zero_point, 0);
806     }
807   }
808 
809   // check op and versioning.
810   EXPECT_EQ(model_.operator_codes.size(), 1);
811   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
812             BuiltinOperator_CONV_2D);
813   EXPECT_EQ(model_.operator_codes[0]->version, 3);
814 }
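
// Illustrative sketch (an assumption, not the quantizer's actual code): how a
// single float bias value maps onto the quantized bias checked above, using
// the per-channel rule bias_scale[i] = input_scale * weight_scale[i].
// Assumes <cmath> for std::lround; the helper name is hypothetical.
int32_t QuantizeBiasValueForTest(float bias_value, float input_scale,
                                 float weight_scale) {
  const float bias_scale = input_scale * weight_scale;
  return static_cast<int32_t>(std::lround(bias_value / bias_scale));
}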
815 
816 class QuantizeSoftmaxTest : public QuantizeModelTest {
817  protected:
818   QuantizeSoftmaxTest() {
819     input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
820     readonly_model_ = input_model_->GetModel();
821     readonly_model_->UnPackTo(&model_);
822   }
823 };
824 
825 TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
826   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
827                                           TensorType_INT8, false,
828                                           TensorType_INT8, &error_reporter_);
829   ASSERT_EQ(kTfLiteOk, status);
830 
831   const auto& subgraph = model_.subgraphs[0];
832   auto op = subgraph->operators[0].get();
833   // Model has a single softmax op.
834   ASSERT_EQ(op->opcode_index, 0);
835   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
836             BuiltinOperator_SOFTMAX);
837 
838   ASSERT_EQ(op->inputs.size(), 1);
839   ASSERT_EQ(op->outputs.size(), 1);
840   auto float_graph = readonly_model_->subgraphs()->Get(0);
841 
842   // Verify input.
843   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
844             TensorType_FLOAT32);
845   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
846             TensorType_FLOAT32);
847 
848   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
849   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
850 
851   auto float_input_quant_params =
852       float_graph->tensors()->Get(op->inputs[0])->quantization();
853   auto input_quant_params =
854       subgraph->tensors[op->inputs[0]]->quantization.get();
855   VerifyAsymmetricQuantizationScale(*float_input_quant_params,
856                                     *input_quant_params);
857 
858   // Verify output.
859   auto float_output_quant_params =
860       float_graph->tensors()->Get(op->outputs[0])->quantization();
861   auto output_quant_params =
862       subgraph->tensors[op->outputs[0]]->quantization.get();
863   ASSERT_EQ(float_output_quant_params->min()->size(), 1);
864   ASSERT_EQ(float_output_quant_params->max()->size(), 1);
865 
866   ASSERT_EQ(output_quant_params->scale.size(), 1);
867   ASSERT_EQ(output_quant_params->zero_point.size(), 1);
868   ASSERT_EQ(1.0f / 256.0f, output_quant_params->scale[0]);
869   ASSERT_EQ(-128, output_quant_params->zero_point[0]);
870 
871   // check op and versioning.
872   EXPECT_EQ(model_.operator_codes.size(), 1);
873   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
874             BuiltinOperator_SOFTMAX);
875   EXPECT_EQ(model_.operator_codes[0]->version, 2);
876 }
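
// Note (informational addition): softmax outputs lie in [0, 1), so the int8
// output quantization is pinned to scale 1/256 with zero point -128, which is
// what the two assertions on output_quant_params above check.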
877 
878 class QuantizeAvgPoolTest : public QuantizeModelTest {
879  protected:
880   QuantizeAvgPoolTest() {
881     input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
882     readonly_model_ = input_model_->GetModel();
883     readonly_model_->UnPackTo(&model_);
884   }
885 };
886 
887 TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
888   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
889                                           TensorType_INT8, false,
890                                           TensorType_INT8, &error_reporter_);
891   ASSERT_EQ(kTfLiteOk, status);
892 
893   const auto& subgraph = model_.subgraphs[0];
894   auto op = subgraph->operators[0].get();
895   // Model has a single AveragePool op.
896   ASSERT_EQ(op->opcode_index, 0);
897   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
898             BuiltinOperator_AVERAGE_POOL_2D);
899 
900   ASSERT_EQ(op->inputs.size(), 1);
901   ASSERT_EQ(op->outputs.size(), 1);
902 
903   auto float_graph = readonly_model_->subgraphs()->Get(0);
904   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
905             TensorType_FLOAT32);
906   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
907             TensorType_FLOAT32);
908 
909   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
910   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
911 
912   auto float_input_quant_params =
913       float_graph->tensors()->Get(op->inputs[0])->quantization();
914   auto input_quant_params =
915       subgraph->tensors[op->inputs[0]]->quantization.get();
916   VerifyAsymmetricQuantizationScale(*float_input_quant_params,
917                                     *input_quant_params);
918 
919   auto float_output_quant_params =
920       float_graph->tensors()->Get(op->outputs[0])->quantization();
921   auto output_quant_params =
922       subgraph->tensors[op->outputs[0]]->quantization.get();
923   ASSERT_EQ(float_output_quant_params->min()->size(), 1);
924   ASSERT_EQ(float_output_quant_params->max()->size(), 1);
925   ASSERT_EQ(output_quant_params->min.size(), 1);
926   ASSERT_EQ(output_quant_params->max.size(), 1);
927 
928   // Make sure the input min/maxes are propagated to outputs.
929   EXPECT_EQ(input_quant_params->min[0], output_quant_params->min[0]);
930   EXPECT_EQ(input_quant_params->max[0], output_quant_params->max[0]);
931   EXPECT_EQ(input_quant_params->scale[0], output_quant_params->scale[0]);
932 
933   // check op and versioning.
934   EXPECT_EQ(model_.operator_codes.size(), 1);
935   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
936             BuiltinOperator_AVERAGE_POOL_2D);
937   EXPECT_EQ(model_.operator_codes[0]->version, 2);
938 }
939 
940 class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
941  protected:
942   QuantizeMultiInputAddWithReshapeTest() {
943     input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
944     readonly_model_ = input_model_->GetModel();
945     readonly_model_->UnPackTo(&model_);
946   }
947 };
948 
949 TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
950   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
951                                           TensorType_INT8, false,
952                                           TensorType_INT8, &error_reporter_);
953   ASSERT_EQ(kTfLiteOk, status);
954 
955   // Verify Reshape is quantized.
956   const auto& subgraph = model_.subgraphs[0];
957   auto op = subgraph->operators[1].get();
958   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
959             BuiltinOperator_RESHAPE);
960 
961   ASSERT_EQ(op->inputs.size(), 2);
962   ASSERT_EQ(op->outputs.size(), 1);
963 
964   auto float_graph = readonly_model_->subgraphs()->Get(0);
965   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
966             TensorType_FLOAT32);
967   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
968             TensorType_FLOAT32);
969 
970   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
971   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
972   auto float_input_quant_params =
973       float_graph->tensors()->Get(op->inputs[0])->quantization();
974   auto input_quant_params =
975       subgraph->tensors[op->inputs[0]]->quantization.get();
976   VerifyAsymmetricQuantizationScale(*float_input_quant_params,
977                                     *input_quant_params);
978 
979   auto float_output_quant_params =
980       float_graph->tensors()->Get(op->outputs[0])->quantization();
981   auto output_quant_params =
982       subgraph->tensors[op->outputs[0]]->quantization.get();
983   ASSERT_EQ(float_output_quant_params->min()->size(), 1);
984   ASSERT_EQ(float_output_quant_params->max()->size(), 1);
985   ASSERT_EQ(output_quant_params->min.size(), 1);
986   ASSERT_EQ(output_quant_params->max.size(), 1);
987 
988   // check op and versioning.
989   EXPECT_EQ(model_.operator_codes.size(), 2);
990   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
991             BuiltinOperator_ADD);
992   EXPECT_EQ(model_.operator_codes[0]->version, 2);
993   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
994             BuiltinOperator_RESHAPE);
995   EXPECT_EQ(model_.operator_codes[1]->version, 1);
996 }
997 
998 TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
999   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
1000                                           TensorType_INT8, false,
1001                                           TensorType_INT8, &error_reporter_);
1002   ASSERT_EQ(kTfLiteOk, status);
1003 
1004   // Verify ADD is quantized.
1005   const auto& subgraph = model_.subgraphs[0];
1006   auto op = subgraph->operators[0].get();
1007   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1008             BuiltinOperator_ADD);
1009 
1010   ASSERT_EQ(op->inputs.size(), 2);
1011   ASSERT_EQ(op->outputs.size(), 1);
1012 
1013   auto float_graph = readonly_model_->subgraphs()->Get(0);
1014   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1015             TensorType_FLOAT32);
1016   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
1017             TensorType_FLOAT32);
1018   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1019             TensorType_FLOAT32);
1020 
1021   for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
1022     EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
1023               TensorType_INT8);
1024     auto float_input_quant_params =
1025         float_graph->tensors()->Get(op->inputs[input_idx])->quantization();
1026     auto input_quant_params =
1027         subgraph->tensors[op->inputs[input_idx]]->quantization.get();
1028     VerifyAsymmetricQuantizationScale(*float_input_quant_params,
1029                                       *input_quant_params);
1030   }
1031 
1032   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
1033   auto float_output_quant_params =
1034       float_graph->tensors()->Get(op->outputs[0])->quantization();
1035   auto output_quant_params =
1036       subgraph->tensors[op->outputs[0]]->quantization.get();
1037   ASSERT_EQ(float_output_quant_params->min()->size(), 1);
1038   ASSERT_EQ(float_output_quant_params->max()->size(), 1);
1039   ASSERT_EQ(output_quant_params->min.size(), 1);
1040   ASSERT_EQ(output_quant_params->max.size(), 1);
1041 
1042   // check op and versioning.
1043   EXPECT_EQ(model_.operator_codes.size(), 2);
1044   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1045             BuiltinOperator_ADD);
1046   EXPECT_EQ(model_.operator_codes[0]->version, 2);
1047   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
1048             BuiltinOperator_RESHAPE);
1049   EXPECT_EQ(model_.operator_codes[1]->version, 1);
1050 }
1051 
1052 class QuantizeConstInputTest : public QuantizeModelTest,
1053                                public testing::WithParamInterface<TensorType> {
1054  protected:
1055   QuantizeConstInputTest() {
1056     tensor_type_ = GetParam();
1057     input_model_ = ReadModel(internal::kConstInputAddModel);
1058     readonly_model_ = input_model_->GetModel();
1059     readonly_model_->UnPackTo(&model_);
1060   }
1061 
1062   TensorType tensor_type_;
1063 };
1064 INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest,
1065                          testing::ValuesIn({TensorType_INT8,
1066                                             TensorType_INT16}));
1067 
1068 TEST_P(QuantizeConstInputTest, VerifyConstOpInput) {
1069   auto status =
1070       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
1071                                 false, tensor_type_, &error_reporter_);
1072   ASSERT_EQ(kTfLiteOk, status);
1073 
1074   // Verify ConstOp is quantized.
1075   const auto& subgraph = model_.subgraphs[0];
1076   auto op = subgraph->operators[0].get();
1077   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1078             BuiltinOperator_ADD);
1079 
1080   ASSERT_EQ(op->inputs.size(), 2);
1081   ASSERT_EQ(op->outputs.size(), 1);
1082 
1083   auto float_graph = readonly_model_->subgraphs()->Get(0);
1084   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1085             TensorType_FLOAT32);
1086   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1087             TensorType_FLOAT32);
1088 
1089   for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
1090     EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
1091               tensor_type_);
1092   }
1093 
1094   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, tensor_type_);
1095 
1096   // check op and versioning.
1097   EXPECT_EQ(model_.operator_codes.size(), 1);
1098   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1099             BuiltinOperator_ADD);
1100   EXPECT_EQ(model_.operator_codes[0]->version, 2);
1101 
1102   // Check that, for int16 activations, the pot_scale_int16 option on the ADD
1103   // op is set to false (general, non power-of-two scales are used).
1104   if (tensor_type_ == TensorType_INT16) {
1105     EXPECT_EQ(subgraph->operators[0]
1106                   .get()
1107                   ->builtin_options.AsAddOptions()
1108                   ->pot_scale_int16,
1109               false);
1110   }
1111 }
1112 class QuantizeArgMaxTest : public QuantizeModelTest {
1113  protected:
1114   QuantizeArgMaxTest() {
1115     input_model_ = ReadModel(internal::kModelWithArgMaxOp);
1116     readonly_model_ = input_model_->GetModel();
1117     readonly_model_->UnPackTo(&model_);
1118   }
1119 };
1120 
1121 TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
1122   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
1123                                           TensorType_INT8, false,
1124                                           TensorType_INT8, &error_reporter_);
1125   ASSERT_EQ(kTfLiteOk, status);
1126 
1127   const auto& subgraph = model_.subgraphs[0];
1128   auto op = subgraph->operators[0].get();
1129   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1130             BuiltinOperator_ARG_MAX);
1131 
1132   ASSERT_EQ(op->inputs.size(), 2);
1133   ASSERT_EQ(op->outputs.size(), 1);
1134 
1135   auto float_graph = readonly_model_->subgraphs()->Get(0);
1136   // Verify ArgMax input is quantized.
1137   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1138             TensorType_FLOAT32);
1139   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
1140 
1141   // Verify that the ArgMax axis input is still the same type.
1142   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
1143             subgraph->tensors[op->inputs[1]].get()->type);
1144 
1145   // The output of ArgMax should still be the same type.
1146   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1147             subgraph->tensors[op->outputs[0]].get()->type);
1148 
1149   // check op and versioning.
1150   EXPECT_EQ(model_.operator_codes.size(), 1);
1151   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1152             BuiltinOperator_ARG_MAX);
1153   EXPECT_EQ(model_.operator_codes[0]->version, 2);
1154 }
1155 
1156 class QuantizeLSTMTest : public QuantizeModelTest {
1157  protected:
1158   QuantizeLSTMTest() {
1159     input_model_ = ReadModel(internal::kLstmCalibrated);
1160     readonly_model_ = input_model_->GetModel();
1161     readonly_model_->UnPackTo(&model_);
1162   }
1163 };
1164 
1165 TEST_F(QuantizeLSTMTest, VerifyLSTM) {
1166   // Quantize model.
1167   auto status = QuantizeModelAllOperators(
1168       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1169       TensorType_INT8, &error_reporter_);
1170   ASSERT_EQ(kTfLiteOk, status);
1171 
1172   // Read expected model.
1173   auto expected_fb_model = ReadModel(internal::kLstmQuantized);
1174   auto expected_read_only_model = expected_fb_model->GetModel();
1175   ModelT expected_model;
1176   expected_read_only_model->UnPackTo(&expected_model);
1177 
1178   ExpectSameModels(model_, expected_model);
1179 }
1180 
1181 class QuantizeLSTM2Test : public QuantizeModelTest {
1182  protected:
1183   QuantizeLSTM2Test() {
1184     input_model_ = ReadModel(internal::kLstmCalibrated2);
1185     readonly_model_ = input_model_->GetModel();
1186     readonly_model_->UnPackTo(&model_);
1187   }
1188 };
1189 
1190 TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
1191   // Quantize model.
1192   auto status = QuantizeModelAllOperators(
1193       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1194       TensorType_INT8, &error_reporter_);
1195   ASSERT_EQ(kTfLiteOk, status);
1196 
1197   // Read expected model.
1198   auto expected_fb_model = ReadModel(internal::kLstmQuantized2);
1199   auto expected_read_only_model = expected_fb_model->GetModel();
1200   ModelT expected_model;
1201   expected_read_only_model->UnPackTo(&expected_model);
1202 
1203   ExpectSameModels(model_, expected_model);
1204 }
1205 
1206 class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
1207  protected:
1208   QuantizeUnidirectionalSequenceLSTMTest() {
1209     input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
1210     readonly_model_ = input_model_->GetModel();
1211     readonly_model_->UnPackTo(&model_);
1212   }
1213 };
1214 
1215 TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
1216        VerifyUnidirectionalSequenceLSTM) {
1217   // Quantize model.
1218   auto status = QuantizeModelAllOperators(
1219       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1220       TensorType_INT8, &error_reporter_);
1221   ASSERT_EQ(kTfLiteOk, status);
1222 
1223   // Read expected model.
1224   auto expected_fb_model =
1225       ReadModel(internal::kUnidirectionalSequenceLstmQuantized);
1226   auto expected_read_only_model = expected_fb_model->GetModel();
1227   ModelT expected_model;
1228   expected_read_only_model->UnPackTo(&expected_model);
1229 
1230   ExpectSameModels(model_, expected_model);
1231 }
1232 
1233 class QuantizeSVDFTest : public QuantizeModelTest {
1234  protected:
1235   QuantizeSVDFTest() {
1236     input_model_ = ReadModel(internal::kSvdfCalibrated);
1237     readonly_model_ = input_model_->GetModel();
1238     readonly_model_->UnPackTo(&model_);
1239   }
1240 };
1241 
1242 TEST_F(QuantizeSVDFTest, VerifySVDF) {
1243   // Quantize model.
1244   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
1245                                           TensorType_INT8, false,
1246                                           TensorType_INT8, &error_reporter_);
1247   ASSERT_EQ(kTfLiteOk, status);
1248 
1249   // Read expected model.
1250   auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
1251   auto expected_read_only_model = expected_fb_model->GetModel();
1252   ModelT expected_model;
1253   expected_read_only_model->UnPackTo(&expected_model);
1254 
1255   // Comparison.
1256   ASSERT_EQ(model_.subgraphs.size(), expected_model.subgraphs.size());
1257   for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
1258        subgraph_idx++) {
1259     const auto graph = model_.subgraphs[subgraph_idx].get();
1260     const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
1261     ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size());
1262     for (size_t i = 0; i < graph->tensors.size(); i++) {
1263       const auto tensor = graph->tensors[i].get();
1264       const auto expected_tensor = expected_graph->tensors[i].get();
1265       EXPECT_EQ(tensor->buffer, expected_tensor->buffer);
1266       EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
1267       EXPECT_EQ(tensor->shape, expected_tensor->shape);
1268       EXPECT_EQ(tensor->name, expected_tensor->name);
1269       EXPECT_EQ(tensor->type, expected_tensor->type);
1270       const auto quantization_params = tensor->quantization.get();
1271       const auto expected_quantization_params =
1272           expected_tensor->quantization.get();
1273       if (quantization_params != nullptr ||
1274           expected_quantization_params != nullptr) {
1275         EXPECT_NE(quantization_params, nullptr);
1276         EXPECT_NE(expected_quantization_params, nullptr);
1277         EXPECT_EQ(quantization_params->scale,
1278                   expected_quantization_params->scale);
1279         EXPECT_EQ(quantization_params->zero_point,
1280                   expected_quantization_params->zero_point);
1281       }
1282     }
1283   }
1284   ASSERT_EQ(model_.buffers.size(), expected_model.buffers.size());
1285   for (size_t buffer_idx = 0; buffer_idx < model_.buffers.size();
1286        ++buffer_idx) {
1287     const auto buffer = model_.buffers[buffer_idx].get()->data;
1288     const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data;
1289     EXPECT_EQ(buffer, expected_buffer);
1290   }
1291 }
1292 
1293 class QuantizeFCTest : public QuantizeModelTest {
1294  protected:
1295   QuantizeFCTest() {
1296     input_model_ = ReadModel(internal::kModelWithFCOp);
1297     readonly_model_ = input_model_->GetModel();
1298     readonly_model_->UnPackTo(&model_);
1299   }
1300 };
1301 
1302 TEST_F(QuantizeFCTest, VerifyFC) {
1303   auto status = QuantizeModelAllOperators(&builder_, &model_, TensorType_INT8,
1304                                           TensorType_INT8, false,
1305                                           TensorType_INT8, &error_reporter_);
1306   ASSERT_EQ(kTfLiteOk, status);
1307 
1308   const auto& subgraph = model_.subgraphs[0];
1309   auto op = subgraph->operators[0].get();
1310   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1311             BuiltinOperator_FULLY_CONNECTED);
1312 
1313   ASSERT_EQ(op->inputs.size(), 3);
1314   ASSERT_EQ(op->outputs.size(), 1);
1315 
1316   auto float_graph = readonly_model_->subgraphs()->Get(0);
1317   // Verify the FC input and weights are quantized.
1318   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1319             TensorType_FLOAT32);
1320   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
1321   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
1322             TensorType_FLOAT32);
1323   EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);
1324 
1325   // Verify the FC bias is quantized to int32.
1326   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[2])->type(),
1327             TensorType_FLOAT32);
1328   EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);
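  // Per the TFLite int8 quantization spec, the bias is expected to use
  // scale = input_scale * weight_scale with a zero point of 0 (not verified
  // by this test).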
1329 
1330   // The output of FC should be quantized.
1331   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1332             TensorType_FLOAT32);
1333   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
1334 
1335   // check op and versioning.
1336   EXPECT_EQ(model_.operator_codes.size(), 2);
1337   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1338             BuiltinOperator_FULLY_CONNECTED);
1339   EXPECT_EQ(model_.operator_codes[0]->version, 4);
1340   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
1341             BuiltinOperator_RESHAPE);
1342   EXPECT_EQ(model_.operator_codes[1]->version, 1);
1343 }
1344 
1345 class QuantizeCustomOpTest
1346     : public QuantizeModelTest,
1347       public ::testing::WithParamInterface<tflite::TensorType> {
1348  protected:
1349   QuantizeCustomOpTest() {
1350     input_model_ = ReadModel(internal::kModelMixed);
1351     readonly_model_ = input_model_->GetModel();
1352     readonly_model_->UnPackTo(&model_);
1353   }
1354 };
1355 
1356 TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
1357   auto status = QuantizeModelAllOperators(
1358       &builder_, &model_, GetParam(), GetParam(),
1359       /*allow_float=*/true, GetParam(), &error_reporter_);
1360   ASSERT_EQ(kTfLiteOk, status);
1361   const auto& subgraph = model_.subgraphs[0];
1362   auto float_graph = readonly_model_->subgraphs()->Get(0);
1363   // The original model is reshape->custom->custom->squeeze.
1364   ASSERT_EQ(float_graph->operators()->size(), 4);
1365   // The resulting model should be:
1366   // reshape->dequantize->custom->custom->quantize->squeeze.
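  // Custom ops have no quantized kernels, so with allow_float=true the
  // quantizer leaves them in float and wraps the float island with a
  // Dequantize before it and a Quantize after it.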
1367   ASSERT_EQ(subgraph->operators.size(), 6);
1368   const std::vector<BuiltinOperator> op_codes = {
1369       BuiltinOperator_RESHAPE,  BuiltinOperator_DEQUANTIZE,
1370       BuiltinOperator_CUSTOM,   BuiltinOperator_CUSTOM,
1371       BuiltinOperator_QUANTIZE, BuiltinOperator_SQUEEZE};
1372   const std::vector<TensorType> op_input_types = {
1373       GetParam(),         GetParam(),         TensorType_FLOAT32,
1374       TensorType_FLOAT32, TensorType_FLOAT32, GetParam()};
1375   for (int i = 0; i < subgraph->operators.size(); ++i) {
1376     OperatorT* op = subgraph->operators[i].get();
1377     ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1378               op_codes[i]);
1379     ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
1380   }
1381 }
1382 
1383 INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
1384                          ::testing::Values(TensorType_INT8, TensorType_INT16));
1385 
1386 class QuantizeOp16x8Test : public QuantizeModelTest {
1387  protected:
1388   QuantizeOp16x8Test() {
1389     input_model_ = ReadModel(internal::kModelMixed16x8);
1390     readonly_model_ = input_model_->GetModel();
1391     readonly_model_->UnPackTo(&model_);
1392   }
1393 };
1394 
1395 TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
1396   auto status = QuantizeModelAllOperators(
1397       &builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
1398       /*allow_float=*/true, TensorType_INT16, &error_reporter_);
1399   ASSERT_EQ(kTfLiteOk, status);
1400   const auto& subgraph = model_.subgraphs[0];
1401   auto float_graph = readonly_model_->subgraphs()->Get(0);
1402   // The original model is conv_2d->log_softmax.
1403   ASSERT_EQ(float_graph->operators()->size(), 2);
1404   // The resulting model should be:
1405   // conv_2d->dequantize->log_softmax
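  // log_softmax is left in float in this 16x8 mode, so a Dequantize is
  // inserted between the quantized conv_2d output and the float log_softmax
  // input.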
1406   ASSERT_EQ(subgraph->operators.size(), 3);
1407   const std::vector<BuiltinOperator> op_codes = {BuiltinOperator_CONV_2D,
1408                                                  BuiltinOperator_DEQUANTIZE,
1409                                                  BuiltinOperator_LOG_SOFTMAX};
1410   const std::vector<TensorType> op_input_types = {
1411       TensorType_INT16, TensorType_INT16, TensorType_FLOAT32};
1412   for (int i = 0; i < subgraph->operators.size(); ++i) {
1413     OperatorT* op = subgraph->operators[i].get();
1414     ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1415               op_codes[i]);
1416     ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
1417   }
1418 }
1419 
1420 class QuantizePackTest : public QuantizeModelTest {
1421  protected:
1422   QuantizePackTest() {
1423     input_model_ = ReadModel(internal::kModelPack);
1424     readonly_model_ = input_model_->GetModel();
1425     readonly_model_->UnPackTo(&model_);
1426   }
1427 };
1428 
1429 TEST_F(QuantizePackTest, VerifyPack) {
1430   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1431 
1432   ASSERT_EQ(kTfLiteOk, status);
1433 
1434   const auto subgraph = model_.subgraphs[0].get();
1435 
1436   // The model should only have 3 inputs and 1 output.
1437   EXPECT_EQ(subgraph->inputs.size(), 3);
1438   EXPECT_EQ(subgraph->outputs.size(), 1);
1439 
1440   const auto& op1 = subgraph->operators[1].get();
1441   const auto& op2 = subgraph->operators[2].get();
1442   const auto& op3 = subgraph->operators[3].get();
1443   const auto& op4 = subgraph->operators[4].get();
1444 
1445   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op1->opcode_index].get()),
1446             BuiltinOperator_QUANTIZE);
1447   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op2->opcode_index].get()),
1448             BuiltinOperator_QUANTIZE);
1449   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op3->opcode_index].get()),
1450             BuiltinOperator_PACK);
1451   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op4->opcode_index].get()),
1452             BuiltinOperator_DEQUANTIZE);
1453 
1454   const auto& pack_input0 = subgraph->tensors[op3->inputs[0]].get();
1455   const auto& pack_input1 = subgraph->tensors[op3->inputs[1]].get();
1456   const auto& pack_input2 = subgraph->tensors[op3->inputs[2]].get();
1457 
1458   const auto& pack_output = subgraph->tensors[op3->outputs[0]].get();
1459 
1460   // Check quantization parameters for input and output.
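  // Pack concatenates its inputs without rescaling, so all inputs and the
  // output are expected to share a single scale and zero point.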
1461   EXPECT_FLOAT_EQ(pack_input0->quantization->scale[0],
1462                   pack_input1->quantization->scale[0]);
1463   EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
1464                   pack_input2->quantization->scale[0]);
1465   EXPECT_FLOAT_EQ(pack_input0->quantization->zero_point[0],
1466                   pack_input1->quantization->zero_point[0]);
1467   EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
1468                   pack_input2->quantization->zero_point[0]);
1469 
1470   EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
1471                   pack_output->quantization->scale[0]);
1472   EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
1473                   pack_output->quantization->zero_point[0]);
1474 
1475   // Check type of input and output.
1476   EXPECT_EQ(pack_output->type, TensorType_INT8);
1477   EXPECT_EQ(pack_input0->type, TensorType_INT8);
1478   EXPECT_EQ(pack_input1->type, TensorType_INT8);
1479   EXPECT_EQ(pack_input2->type, TensorType_INT8);
1480 }
1481 
1482 class QuantizeMinimumMaximumTest
1483     : public QuantizeModelTest,
1484       public testing::WithParamInterface<const char*> {
1485  protected:
1486   QuantizeMinimumMaximumTest() {
1487     input_model_ = ReadModel(GetParam());
1488     readonly_model_ = input_model_->GetModel();
1489     readonly_model_->UnPackTo(&model_);
1490   }
1491 };
1492 
1493 TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
1494   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1495   ASSERT_EQ(kTfLiteOk, status);
1496   const auto& subgraph = model_.subgraphs[0];
1497 
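  // Expected layout after quantization: Quantize (input) -> two requantize
  // Quantize ops (to align the operand scales) -> Minimum/Maximum ->
  // Dequantize.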
1498   // Check that the first op is Quantize and the last is Dequantize.
1499   const auto& quant_op = subgraph->operators[0];
1500   const auto& dequant_op = subgraph->operators[subgraph->operators.size() - 1];
1501   const int32_t quant_idx = quant_op->opcode_index;
1502   const int32_t dequant_idx = dequant_op->opcode_index;
1503   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()),
1504             BuiltinOperator_QUANTIZE);
1505   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()),
1506             BuiltinOperator_DEQUANTIZE);
1507   const auto& requant1 = subgraph->operators[1].get();
1508   // Check that this is a requantize (Quantize) operator.
1509   auto requant1_builtin_code =
1510       GetBuiltinCode(model_.operator_codes[requant1->opcode_index].get());
1511   ASSERT_TRUE(requant1_builtin_code == tflite::BuiltinOperator_QUANTIZE);
1512 
1513   const auto& requant2 = subgraph->operators[2].get();
1514   // Check that this is a requantize (Quantize) operator.
1515   auto requant2_builtin_code =
1516       GetBuiltinCode(model_.operator_codes[requant2->opcode_index].get());
1517   ASSERT_TRUE(requant2_builtin_code == tflite::BuiltinOperator_QUANTIZE);
1518 
1519   const auto& op = subgraph->operators[3].get();
1520 
1521   // Check that we have MINIMUM or MAXIMUM operator.
1522   auto op_builtin_code =
1523       GetBuiltinCode(model_.operator_codes[op->opcode_index].get());
1524   ASSERT_TRUE(op_builtin_code == tflite::BuiltinOperator_MINIMUM ||
1525               op_builtin_code == tflite::BuiltinOperator_MAXIMUM);
1526 
1527   // Check that we have two inputs and one output.
1528   ASSERT_EQ(op->inputs.size(), 2);
1529   ASSERT_EQ(op->outputs.size(), 1);
1530 
1531   // Check that all tensors are quantized.
1532   auto output = subgraph->tensors[op->outputs[0]].get();
1533   auto input1 = subgraph->tensors[op->inputs[0]].get();
1534   auto input2 = subgraph->tensors[op->inputs[1]].get();
1535 
1536   EXPECT_EQ(output->type, TensorType_INT8);
1537   EXPECT_EQ(input1->type, TensorType_INT8);
1538   EXPECT_EQ(input2->type, TensorType_INT8);
1539 
1540   // Check if the quantization params of the minimum/maximum inputs match
1541   // after requantization
1542   EXPECT_EQ(input1->quantization->scale, input2->quantization->scale);
1543   EXPECT_EQ(input1->quantization->zero_point, input2->quantization->zero_point);
1544 
1545   // Check the input quantization params match the output ones.
1546   EXPECT_EQ(output->quantization->scale, input1->quantization->scale);
1547   EXPECT_EQ(output->quantization->zero_point, input1->quantization->zero_point);
1548   EXPECT_EQ(output->quantization->scale, input2->quantization->scale);
1549   EXPECT_EQ(output->quantization->zero_point, input2->quantization->zero_point);
1550 
1551   EXPECT_EQ(subgraph->tensors.size(), 7);
1552 
1553   EXPECT_EQ(subgraph->tensors[0]->name, "input_int8");
1554   EXPECT_EQ(subgraph->tensors[1]->name, "output_int8");
1555   EXPECT_EQ(subgraph->tensors[2]->name, "output/y");
1556   EXPECT_EQ(subgraph->tensors[3]->name, "input_requantized");
1557   EXPECT_EQ(subgraph->tensors[4]->name, "output/y_requantized");
1558   EXPECT_EQ(subgraph->tensors[5]->name, "input");
1559   EXPECT_EQ(subgraph->tensors[6]->name, "output");
1560 }
1561 
1562 INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
1563                          testing::ValuesIn({internal::kModelWithMinimumOp,
1564                                             internal::kModelWithMaximumOp}));
1565 
1566 class QuantizeUnpackTest : public QuantizeModelTest {
1567  protected:
1568   QuantizeUnpackTest() {
1569     input_model_ = ReadModel(internal::kModelWithUnpack);
1570     readonly_model_ = input_model_->GetModel();
1571     readonly_model_->UnPackTo(&model_);
1572   }
1573 };
1574 TEST_F(QuantizeUnpackTest, VerifyUnpack) {
1575   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1576 
1577   ASSERT_EQ(kTfLiteOk, status);
1578 
1579   const auto subgraph = model_.subgraphs[0].get();
1580   auto op = subgraph->operators[1].get();
1581 
1582   auto float_graph = readonly_model_->subgraphs()->Get(0);
1583 
1584   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1585             BuiltinOperator_UNPACK);
1586 
1587   // Get unpack input and output tensors
1588   auto unpack_input = subgraph->tensors[op->inputs[0]].get();
1589   auto unpack_output_0 = subgraph->tensors[op->outputs[0]].get();
1590   auto unpack_output_1 = subgraph->tensors[op->outputs[1]].get();
1591 
1592   // Verify Unpack input is quantized.
1593   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1594             TensorType_FLOAT32);
1595   EXPECT_EQ(unpack_input->type, TensorType_INT8);
1596 
1597   // The model should only have one input and 2 outputs.
1598   EXPECT_EQ(subgraph->inputs.size(), 1);
1599   EXPECT_EQ(subgraph->outputs.size(), 2);
1600 
1601   // Ensure the quantization parameters are propagated unchanged from the
1602   // Unpack input to both of its outputs (same scale and
1603   // zero point).
1604   EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1605                   unpack_output_0->quantization->scale[0]);
1606   EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1607                   unpack_output_1->quantization->scale[0]);
1608   EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1609                   unpack_output_0->quantization->zero_point[0]);
1610   EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1611                   unpack_output_1->quantization->zero_point[0]);
1612 }
1613 
1614 class QuantizeTransposeTest : public QuantizeModelTest {
1615  protected:
1616   QuantizeTransposeTest() {
1617     input_model_ = ReadModel(internal::kModelWithTranspose);
1618     readonly_model_ = input_model_->GetModel();
1619     readonly_model_->UnPackTo(&model_);
1620   }
1621 };
1622 
1623 TEST_F(QuantizeTransposeTest, VerifyTranspose) {
1624   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1625 
1626   ASSERT_EQ(kTfLiteOk, status);
1627 
1628   const auto subgraph = model_.subgraphs[0].get();
1629   auto op = subgraph->operators[1].get();
1630 
1631   auto float_graph = readonly_model_->subgraphs()->Get(0);
1632 
1633   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1634             BuiltinOperator_TRANSPOSE);
1635 
1636   // The model should only have one input and one output.
1637   EXPECT_EQ(subgraph->inputs.size(), 1);
1638   EXPECT_EQ(subgraph->outputs.size(), 1);
1639 
1640   // Get transpose input and output tensors
1641   auto transpose_input = subgraph->tensors[op->inputs[0]].get();
1642   auto transpose_output = subgraph->tensors[op->outputs[0]].get();
1643 
1644   // Verify transpose input is quantized.
1645   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1646             TensorType_FLOAT32);
1647   EXPECT_EQ(transpose_input->type, TensorType_INT8);
1648 
1649   // Ensure the quantization parameters are propagated unchanged from the
1650   // Transpose input to its output (same scale and
1651   // zero point).
1652   EXPECT_FLOAT_EQ(transpose_input->quantization->scale[0],
1653                   transpose_output->quantization->scale[0]);
1654   EXPECT_EQ(transpose_input->quantization->zero_point[0],
1655             transpose_output->quantization->zero_point[0]);
1656 }
1657 
1658 class QuantizeQatTest : public QuantizeModelTest {
1659  protected:
1660   QuantizeQatTest() {
1661     input_model_ = ReadModel(internal::kQatModelWithFc);
1662     readonly_model_ = input_model_->GetModel();
1663     readonly_model_->UnPackTo(&model_);
1664   }
1665 };
1666 
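// The QAT model already carries int8 tensors from quantization-aware
// training; quantizing it again should only add the Quantize/Dequantize pair
// needed for the requested float32 input and output types.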
1667 TEST_F(QuantizeQatTest, VerifySingleQuantize) {
1668   auto status = QuantizeModelAllOperators(
1669       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1670       TensorType_INT8, &error_reporter_);
1671   ASSERT_EQ(kTfLiteOk, status);
1672 
1673   const auto& subgraph = model_.subgraphs[0];
1674   auto op = subgraph->operators[0].get();
1675   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1676             BuiltinOperator_QUANTIZE);
1677   op = subgraph->operators[1].get();
1678   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1679             BuiltinOperator_RESHAPE);
1680   op = subgraph->operators[2].get();
1681   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1682             BuiltinOperator_FULLY_CONNECTED);
1683 
1684   ASSERT_EQ(op->inputs.size(), 3);
1685   ASSERT_EQ(op->outputs.size(), 1);
1686 
1687   auto qat_graph = readonly_model_->subgraphs()->Get(0);
1688   // Verify the FC input and weights are quantized.
1689   ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[0])->type(), TensorType_INT8);
1690   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
1691   ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[1])->type(), TensorType_INT8);
1692   EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);
1693 
1694   // Verify the FC bias is quantized to int32.
1695   ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[2])->type(), TensorType_INT32);
1696   EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);
1697 
1698   // The output of FC should be quantized.
1699   ASSERT_EQ(qat_graph->tensors()->Get(op->outputs[0])->type(), TensorType_INT8);
1700   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
1701 
1702   // check op and versioning.
1703   EXPECT_EQ(model_.operator_codes.size(), 4);
1704   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1705             BuiltinOperator_QUANTIZE);
1706   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
1707             BuiltinOperator_RESHAPE);
1708   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[2].get()),
1709             BuiltinOperator_FULLY_CONNECTED);
1710   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[3].get()),
1711             BuiltinOperator_DEQUANTIZE);
1712   EXPECT_EQ(model_.operator_codes[1]->version, 1);
1713   EXPECT_EQ(model_.operator_codes[2]->version, 4);
1714 }
1715 
1716 class QuantizeBroadcastToModelTest
1717     : public QuantizeModelTest,
1718       public testing::WithParamInterface<TensorType> {
1719  protected:
1720   QuantizeBroadcastToModelTest() {
1721     tensor_type_ = GetParam();
1722     input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
1723     readonly_model_ = input_model_->GetModel();
1724     readonly_model_->UnPackTo(&model_);
1725   }
1726   TensorType tensor_type_;
1727 };
1728 
1729 INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
1730                          QuantizeBroadcastToModelTest,
1731                          testing::ValuesIn({TensorType_INT8,
1732                                             TensorType_INT16}));
1733 
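// BroadcastTo only replicates data, so its input and output are quantized to
// the requested activation type while the int32 shape operand stays
// unquantized.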
1734 TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
1735   auto status =
1736       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
1737                                 false, tensor_type_, &error_reporter_);
1738   EXPECT_EQ(status, kTfLiteOk);
1739 
1740   // There is only one subgraph.
1741   const int32_t subgraph_idx = 0;
1742   const auto& subgraph = model_.subgraphs[subgraph_idx];
1743   const auto& readonly_subgraph =
1744       readonly_model_->subgraphs()->Get(subgraph_idx);
1745 
1746   // There should be a single broadcast_to op.
1747   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1748   EXPECT_EQ(subgraph->operators.size(), 1);
1749   const auto& broadcast_to = subgraph->operators[0];
1750   EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
1751             BuiltinOperator_BROADCAST_TO);
1752 
1753   // There should be 3 tensors: input, output, and BroadcastTo/shape.
1754   EXPECT_EQ(subgraph->tensors.size(), 3);
1755 
1756   // Input Tensor
1757   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
1758   EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
1759   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
1760   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
1761 
1762   // Output Tensor. The name given in the generated
1763   // .bin test file is 'Identity' and should be preserved
1764   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
1765   EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
1766   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
1767   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
1768 
1769   // The BroadcastTo shape is of type INT32 and should not be quantized.
1770   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
1771   EXPECT_EQ(subgraph->tensors[1]->name,
1772             "model/tf.broadcast_to/BroadcastTo/shape");
1773   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1774   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1775 
1776   // check op and versioning.
1777   EXPECT_EQ(model_.operator_codes.size(), 1);
1778   EXPECT_EQ(model_.operator_codes[0]->builtin_code,
1779             BuiltinOperator_BROADCAST_TO);
1780   EXPECT_EQ(model_.operator_codes[0]->version, 3);
1781 }
1782 
1783 class QuantizeGatherNDModelTest
1784     : public QuantizeModelTest,
1785       public testing::WithParamInterface<TensorType> {
1786  protected:
1787   QuantizeGatherNDModelTest() {
1788     tensor_type_ = GetParam();
1789     input_model_ = ReadModel(internal::kModelWithGatherNDOp);
1790     readonly_model_ = input_model_->GetModel();
1791     readonly_model_->UnPackTo(&model_);
1792   }
1793 
1794   TensorType tensor_type_;
1795 };
1796 
1797 INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst,
1798                          QuantizeGatherNDModelTest,
1799                          testing::ValuesIn({TensorType_INT8,
1800                                             TensorType_INT16}));
1801 
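// GatherND copies values selected by integer indices, so the data input and
// output are quantized while the int32 indices tensor stays unquantized.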
1802 TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
1803   auto status =
1804       QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
1805                                 false, tensor_type_, &error_reporter_);
1806   EXPECT_EQ(status, kTfLiteOk);
1807 
1808   // There is only one subgraph.
1809   const int32_t subgraph_idx = 0;
1810   const auto& subgraph = model_.subgraphs[subgraph_idx];
1811   const auto& readonly_subgraph =
1812       readonly_model_->subgraphs()->Get(subgraph_idx);
1813 
1814   // There should be a single gather_nd op.
1815   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1816   EXPECT_EQ(subgraph->operators.size(), 1);
1817   const auto& gather_nd = subgraph->operators[0];
1818   EXPECT_EQ(model_.operator_codes[gather_nd->opcode_index]->builtin_code,
1819             BuiltinOperator_GATHER_ND);
1820 
1821   // There should be 3 tensors: input, output, and indices.
1822   EXPECT_EQ(subgraph->tensors.size(), 3);
1823 
1824   // Input Tensor
1825   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
1826   EXPECT_EQ(subgraph->tensors[0]->name, "input");
1827   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
1828   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
1829 
1830   // Output Tensor
1831   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
1832   EXPECT_EQ(subgraph->tensors[2]->name, "output");
1833   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
1834   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
1835 
1836   // The gather indices are of type INT32 and should not be quantized
1837   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
1838   EXPECT_EQ(subgraph->tensors[1]->name, "indices");
1839   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1840   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1841 
1842   // Check op and versioning.
1843   EXPECT_EQ(model_.operator_codes.size(), 1);
1844   EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_GATHER_ND);
1845   EXPECT_EQ(model_.operator_codes[0]->version, 3);
1846 }
1847 
1848 class QuantizeWhereModelTest : public QuantizeModelTest {
1849  protected:
1850   QuantizeWhereModelTest() {
1851     input_model_ = ReadModel(internal::kModelWithWhereOp);
1852     readonly_model_ = input_model_->GetModel();
1853     readonly_model_->UnPackTo(&model_);
1854   }
1855 };
1856 
1857 TEST_F(QuantizeWhereModelTest, QuantizeWhere) {
1858   // The Where operator takes a BOOL tensor as input
1859   // and outputs INT64 indices; neither of these
1860   // types should be quantized.
1861   auto status = QuantizeModel(&builder_, &model_, TensorType_BOOL,
1862                               TensorType_INT64, &error_reporter_);
1863   EXPECT_EQ(status, kTfLiteOk);
1864 
1865   // There is only one subgraph.
1866   const int32_t subgraph_idx = 0;
1867   const auto& subgraph = model_.subgraphs[subgraph_idx];
1868   const auto& readonly_subgraph =
1869       readonly_model_->subgraphs()->Get(subgraph_idx);
1870 
1871   // There should be a single where op.
1872   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1873   EXPECT_EQ(subgraph->operators.size(), 1);
1874   const auto& where = subgraph->operators[0];
1875   EXPECT_EQ(model_.operator_codes[where->opcode_index]->builtin_code,
1876             BuiltinOperator_WHERE);
1877 
1878   // There should be 2 tensors: input and output.
1879   EXPECT_EQ(subgraph->tensors.size(), 2);
1880 
1881   // Check the input tensor type and ensure it
1882   // was not quantized.
1883   EXPECT_EQ(subgraph->tensors[0]->type, TensorType_BOOL);
1884   EXPECT_EQ(subgraph->tensors[0]->name, "input");
1885   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 0);
1886   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 0);
1887 
1888   // Check the output (indices) tensor type and ensure it
1889   // was not quantized.
1890   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT64);
1891   EXPECT_EQ(subgraph->tensors[1]->name, "indices");
1892   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1893   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1894 
1895   // check op and versioning.
1896   EXPECT_EQ(model_.operator_codes.size(), 1);
1897   EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_WHERE);
1898   EXPECT_EQ(model_.operator_codes[0]->version, 1);
1899 }
1900 
1901 }  // namespace
1902 }  // namespace optimize
1903 }  // namespace tflite
1904 
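// The test binary expects --test_model_file to point at one of the .bin test
// models; the directory of that file is used to locate all models read by
// ReadModel().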
1905 int main(int argc, char** argv) {
1906   tensorflow::string model_file;
1907   const std::vector<tensorflow::Flag> flag_list = {
1908       tensorflow::Flag("test_model_file", &model_file,
1909                        "Path to test tflite model file."),
1910   };
1911 
1912   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
1913   if (!parse_result) {
1914     std::cerr << "Required flag: --test_model_file\n";
1915     std::abort();
1916   }
1917   g_test_model_dir =
1918       new tensorflow::string(tensorflow::io::Dirname(model_file));
1919   ::tensorflow::port::InitMain(argv[0], &argc, &argv);
1920   return RUN_ALL_TESTS();
1921 }
1922