• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/modify_model_interface.h"
16 
17 #include <memory>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "absl/memory/memory.h"
22 #include "tensorflow/lite/model.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 #include "tensorflow/lite/schema/schema_utils.h"
25 
26 namespace tflite {
27 namespace optimize {
28 namespace {
29 
30 // Create a model with 1 quant, 1 FC, 1 dequant
CreateQuantizedModelSingleInputOutput(const TensorType & quantization_type)31 std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput(
32     const TensorType& quantization_type) {
33   auto model = absl::make_unique<ModelT>();
34   auto subgraph = absl::make_unique<tflite::SubGraphT>();
35   auto buffer = absl::make_unique<tflite::BufferT>();
36   auto quant_op_code = absl::make_unique<OperatorCodeT>();
37   auto quant_op = absl::make_unique<OperatorT>();
38   auto fc_op_code = absl::make_unique<OperatorCodeT>();
39   auto fc_op = absl::make_unique<OperatorT>();
40   auto dequant_op_code = absl::make_unique<OperatorCodeT>();
41   auto dequant_op = absl::make_unique<OperatorT>();
42 
43   model->subgraphs.push_back(std::move(subgraph));
44 
45   // Op code
46   quant_op_code->builtin_code = BuiltinOperator_QUANTIZE;
47   quant_op_code->deprecated_builtin_code =
48       static_cast<int8_t>(BuiltinOperator_QUANTIZE);
49   quant_op_code->version = 2;
50 
51   fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
52   fc_op_code->deprecated_builtin_code =
53       static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
54   fc_op_code->version = 2;
55 
56   dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE;
57   dequant_op_code->deprecated_builtin_code =
58       static_cast<int8_t>(BuiltinOperator_DEQUANTIZE);
59   dequant_op_code->version = 2;
60 
61   // Op.
62   quant_op->opcode_index = 0;
63   quant_op->inputs = {0};
64   quant_op->outputs = {1};
65 
66   fc_op->opcode_index = 1;
67   fc_op->inputs = {1};
68   fc_op->outputs = {2};
69 
70   dequant_op->opcode_index = 2;
71   dequant_op->inputs = {2};
72   dequant_op->outputs = {3};
73 
74   model->subgraphs[0]->operators.push_back(std::move(quant_op));
75   model->subgraphs[0]->operators.push_back(std::move(fc_op));
76   model->subgraphs[0]->operators.push_back(std::move(dequant_op));
77 
78   model->operator_codes.push_back(std::move(quant_op_code));
79   model->operator_codes.push_back(std::move(fc_op_code));
80   model->operator_codes.push_back(std::move(dequant_op_code));
81 
82   // Model input/output.
83   model->subgraphs[0]->inputs = {0};
84   model->subgraphs[0]->outputs = {3};
85 
86   // Tensors
87   auto tensor_0 = absl::make_unique<TensorT>();
88   tensor_0->name = "tensor_0";
89   tensor_0->shape = {};
90   tensor_0->type = TensorType_FLOAT32;
91 
92   auto tensor_1 = absl::make_unique<TensorT>();
93   tensor_1->quantization = absl::make_unique<QuantizationParametersT>();
94   tensor_1->quantization->scale.push_back(0.35);
95   tensor_1->quantization->zero_point.push_back(28);
96   tensor_1->name = "tensor_1";
97   tensor_1->shape = {};
98   tensor_1->type = quantization_type;
99 
100   auto tensor_2 = absl::make_unique<TensorT>();
101   tensor_2->quantization = absl::make_unique<QuantizationParametersT>();
102   tensor_2->quantization->scale.push_back(0.12);
103   tensor_2->quantization->zero_point.push_back(50);
104   tensor_2->name = "tensor_2";
105   tensor_2->shape = {};
106   tensor_2->type = quantization_type;
107 
108   auto tensor_3 = absl::make_unique<TensorT>();
109   tensor_3->name = "tensor_3";
110   tensor_3->shape = {};
111   tensor_3->type = TensorType_FLOAT32;
112 
113   model->subgraphs[0]->tensors.push_back(std::move(tensor_0));
114   model->subgraphs[0]->tensors.push_back(std::move(tensor_1));
115   model->subgraphs[0]->tensors.push_back(std::move(tensor_2));
116   model->subgraphs[0]->tensors.push_back(std::move(tensor_3));
117 
118   // Buffer
119   model->buffers.push_back(std::move(buffer));
120 
121   return model;
122 }
123 
124 // Create a model with 2 quant, 1 FC, 2 dequant
125 // The model mimics the behavior of the quantize_model.cc.
CreateQuantizedModelMultipleInputOutput(const TensorType & quantization_type)126 std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput(
127     const TensorType& quantization_type) {
128   auto model = absl::make_unique<ModelT>();
129   auto subgraph = absl::make_unique<tflite::SubGraphT>();
130   auto buffer = absl::make_unique<tflite::BufferT>();
131   auto quant_op_code = absl::make_unique<OperatorCodeT>();
132   auto quant_op_1 = absl::make_unique<OperatorT>();
133   auto quant_op_2 = absl::make_unique<OperatorT>();
134   auto fc_op_code = absl::make_unique<OperatorCodeT>();
135   auto fc_op = absl::make_unique<OperatorT>();
136   auto dequant_op_code = absl::make_unique<OperatorCodeT>();
137   auto dequant_op_1 = absl::make_unique<OperatorT>();
138   auto dequant_op_2 = absl::make_unique<OperatorT>();
139 
140   model->subgraphs.push_back(std::move(subgraph));
141 
142   // Op code
143   quant_op_code->builtin_code = BuiltinOperator_QUANTIZE;
144   quant_op_code->deprecated_builtin_code =
145       static_cast<int8_t>(BuiltinOperator_QUANTIZE);
146   quant_op_code->version = 2;
147 
148   fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
149   fc_op_code->deprecated_builtin_code =
150       static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
151   fc_op_code->version = 2;
152 
153   dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE;
154   dequant_op_code->deprecated_builtin_code =
155       static_cast<int8_t>(BuiltinOperator_DEQUANTIZE);
156   dequant_op_code->version = 2;
157 
158   // Op.
159   quant_op_1->opcode_index = 0;
160   quant_op_1->inputs = {0};
161   quant_op_1->outputs = {2};
162   quant_op_2->opcode_index = 0;
163   quant_op_2->inputs = {1};
164   quant_op_2->outputs = {3};
165 
166   fc_op->opcode_index = 1;
167   fc_op->inputs = {2, 3};
168   fc_op->outputs = {4, 5};
169 
170   dequant_op_1->opcode_index = 2;
171   dequant_op_1->inputs = {4};
172   dequant_op_1->outputs = {6};
173   dequant_op_2->opcode_index = 2;
174   dequant_op_2->inputs = {5};
175   dequant_op_2->outputs = {7};
176 
177   model->subgraphs[0]->operators.push_back(std::move(quant_op_1));
178   model->subgraphs[0]->operators.push_back(std::move(quant_op_2));
179   model->subgraphs[0]->operators.push_back(std::move(fc_op));
180   model->subgraphs[0]->operators.push_back(std::move(dequant_op_1));
181   model->subgraphs[0]->operators.push_back(std::move(dequant_op_2));
182 
183   model->operator_codes.push_back(std::move(quant_op_code));
184   model->operator_codes.push_back(std::move(fc_op_code));
185   model->operator_codes.push_back(std::move(dequant_op_code));
186 
187   // Model input/output.
188   model->subgraphs[0]->inputs = {0, 1};
189   model->subgraphs[0]->outputs = {6, 7};
190 
191   // Tensors
192   auto tensor_0 = absl::make_unique<TensorT>();
193   tensor_0->name = "tensor_0";
194   tensor_0->shape = {};
195   tensor_0->type = TensorType_FLOAT32;
196 
197   auto tensor_1 = absl::make_unique<TensorT>();
198   tensor_1->name = "tensor_1";
199   tensor_1->shape = {};
200   tensor_1->type = TensorType_FLOAT32;
201 
202   auto tensor_2 = absl::make_unique<TensorT>();
203   tensor_2->quantization = absl::make_unique<QuantizationParametersT>();
204   tensor_2->quantization->scale.push_back(0.35);
205   tensor_2->quantization->zero_point.push_back(28);
206   tensor_2->name = "tensor_2";
207   tensor_2->shape = {};
208   tensor_2->type = quantization_type;
209 
210   auto tensor_3 = absl::make_unique<TensorT>();
211   tensor_3->quantization = absl::make_unique<QuantizationParametersT>();
212   tensor_3->quantization->scale.push_back(0.12);
213   tensor_3->quantization->zero_point.push_back(50);
214   tensor_3->name = "tensor_3";
215   tensor_3->shape = {};
216   tensor_3->type = quantization_type;
217 
218   auto tensor_4 = absl::make_unique<TensorT>();
219   tensor_4->quantization = absl::make_unique<QuantizationParametersT>();
220   tensor_4->quantization->scale.push_back(0.45);
221   tensor_4->quantization->zero_point.push_back(28);
222   tensor_4->name = "tensor_4";
223   tensor_4->shape = {};
224   tensor_4->type = quantization_type;
225 
226   auto tensor_5 = absl::make_unique<TensorT>();
227   tensor_5->quantization = absl::make_unique<QuantizationParametersT>();
228   tensor_5->quantization->scale.push_back(0.22);
229   tensor_5->quantization->zero_point.push_back(50);
230   tensor_5->name = "tensor_5";
231   tensor_5->shape = {};
232   tensor_5->type = quantization_type;
233 
234   auto tensor_6 = absl::make_unique<TensorT>();
235   tensor_6->name = "tensor_6";
236   tensor_6->shape = {};
237   tensor_6->type = TensorType_FLOAT32;
238 
239   auto tensor_7 = absl::make_unique<TensorT>();
240   tensor_7->name = "tensor_7";
241   tensor_7->shape = {};
242   tensor_7->type = TensorType_FLOAT32;
243 
244   model->subgraphs[0]->tensors.push_back(std::move(tensor_0));
245   model->subgraphs[0]->tensors.push_back(std::move(tensor_1));
246   model->subgraphs[0]->tensors.push_back(std::move(tensor_2));
247   model->subgraphs[0]->tensors.push_back(std::move(tensor_3));
248   model->subgraphs[0]->tensors.push_back(std::move(tensor_4));
249   model->subgraphs[0]->tensors.push_back(std::move(tensor_5));
250   model->subgraphs[0]->tensors.push_back(std::move(tensor_6));
251   model->subgraphs[0]->tensors.push_back(std::move(tensor_7));
252 
253   // Buffer
254   model->buffers.push_back(std::move(buffer));
255 
256   return model;
257 }
258 
259 // Create a model with 1 FC.
CreateFloatModel()260 std::unique_ptr<ModelT> CreateFloatModel() {
261   auto model = absl::make_unique<ModelT>();
262   auto subgraph = absl::make_unique<tflite::SubGraphT>();
263   auto buffer = absl::make_unique<tflite::BufferT>();
264   auto fc_op_code = absl::make_unique<OperatorCodeT>();
265   auto fc_op = absl::make_unique<OperatorT>();
266 
267   model->subgraphs.push_back(std::move(subgraph));
268 
269   // Op code
270   fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
271   fc_op_code->deprecated_builtin_code =
272       static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
273   fc_op_code->version = 2;
274 
275   // Op.
276   fc_op->opcode_index = 0;
277   fc_op->inputs = {0};
278   fc_op->outputs = {1};
279 
280   model->subgraphs[0]->operators.push_back(std::move(fc_op));
281   model->operator_codes.push_back(std::move(fc_op_code));
282 
283   // Model input/output.
284   model->subgraphs[0]->inputs = {0};
285   model->subgraphs[0]->outputs = {1};
286 
287   // Tensors
288   auto tensor_0 = absl::make_unique<TensorT>();
289   tensor_0->name = "tensor_0";
290   tensor_0->shape = {};
291   tensor_0->type = TensorType_FLOAT32;
292 
293   auto tensor_1 = absl::make_unique<TensorT>();
294   tensor_1->name = "tensor_1";
295   tensor_1->shape = {};
296   tensor_1->type = TensorType_FLOAT32;
297 
298   model->subgraphs[0]->tensors.push_back(std::move(tensor_0));
299   model->subgraphs[0]->tensors.push_back(std::move(tensor_1));
300 
301   // Buffer
302   model->buffers.push_back(std::move(buffer));
303 
304   return model;
305 }
306 
307 struct ModelInterface : ::testing::TestWithParam<tflite::TensorType> {};
308 
TEST_P(ModelInterface, SingleInputOutput) {
  const TensorType quantization_type = GetParam();

  auto model = CreateQuantizedModelSingleInputOutput(quantization_type);

  // Rewrite the model interface to the quantized type in place.
  flatbuffers::FlatBufferBuilder builder;
  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
                                 quantization_type),
            kTfLiteOk);

  // Only the fully-connected op remains; the model inputs/outputs now point
  // at its quantized tensors.
  EXPECT_EQ(model->subgraphs.size(), 1);
  auto& subgraph = model->subgraphs[0];
  // TODO(mnatraj): The float input tensor has not been removed.
  // EXPECT_EQ(subgraph->tensors.size(), 2);
  EXPECT_EQ(subgraph->tensors.size(), 3);
  EXPECT_EQ(subgraph->inputs.size(), 1);
  EXPECT_EQ(subgraph->inputs[0], 1);
  EXPECT_EQ(subgraph->outputs.size(), 1);
  EXPECT_EQ(subgraph->outputs[0], 2);
  EXPECT_EQ(model->operator_codes.size(), 3);
  EXPECT_EQ(subgraph->operators.size(), 1);
  EXPECT_EQ(subgraph->operators[0]->opcode_index, 1);

  const OperatorT* fc = subgraph->operators[0].get();

  // The FC input keeps tensor_1's quantization parameters.
  const TensorT* fc_input = subgraph->tensors[fc->inputs[0]].get();
  EXPECT_EQ(fc_input->name, "tensor_1");
  EXPECT_EQ(fc_input->type, quantization_type);
  EXPECT_FLOAT_EQ(fc_input->quantization->scale[0], 0.35);
  EXPECT_EQ(fc_input->quantization->zero_point[0], 28);

  // The FC output keeps tensor_2's quantization parameters.
  const TensorT* fc_output = subgraph->tensors[fc->outputs[0]].get();
  EXPECT_EQ(fc_output->name, "tensor_2");
  EXPECT_EQ(fc_output->type, quantization_type);
  EXPECT_FLOAT_EQ(fc_output->quantization->scale[0], 0.12);
  EXPECT_EQ(fc_output->quantization->zero_point[0], 50);
}
347 
TEST_P(ModelInterface, MutipleInputOutput) {
  const TensorType quantization_type = GetParam();

  auto model = CreateQuantizedModelMultipleInputOutput(quantization_type);

  // Rewrite the model interface to the quantized type in place.
  flatbuffers::FlatBufferBuilder builder;
  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), quantization_type,
                                 quantization_type),
            kTfLiteOk);

  // Only the fully-connected op remains; the model inputs/outputs now point
  // at its quantized tensors.
  EXPECT_EQ(model->subgraphs.size(), 1);
  auto& subgraph = model->subgraphs[0];
  // TODO (b/158254056): Remove unused inputs and outputs from tensor list
  // EXPECT_EQ(subgraph->tensors.size(), 4);
  EXPECT_EQ(subgraph->tensors.size(), 6);
  EXPECT_EQ(subgraph->inputs.size(), 2);
  EXPECT_EQ(subgraph->inputs[0], 2);
  EXPECT_EQ(subgraph->inputs[1], 3);
  EXPECT_EQ(subgraph->outputs.size(), 2);
  EXPECT_EQ(subgraph->outputs[0], 4);
  EXPECT_EQ(subgraph->outputs[1], 5);
  EXPECT_EQ(model->operator_codes.size(), 3);
  EXPECT_EQ(subgraph->operators.size(), 1);
  EXPECT_EQ(subgraph->operators[0]->opcode_index, 1);

  const OperatorT* fc = subgraph->operators[0].get();

  // Each FC input/output keeps its original quantization parameters.
  auto check_tensor = [&](const TensorT* tensor, const char* name, float scale,
                          int64_t zero_point) {
    EXPECT_EQ(tensor->name, name);
    EXPECT_EQ(tensor->type, quantization_type);
    EXPECT_FLOAT_EQ(tensor->quantization->scale[0], scale);
    EXPECT_EQ(tensor->quantization->zero_point[0], zero_point);
  };
  check_tensor(subgraph->tensors[fc->inputs[0]].get(), "tensor_2", 0.35, 28);
  check_tensor(subgraph->tensors[fc->inputs[1]].get(), "tensor_3", 0.12, 50);
  check_tensor(subgraph->tensors[fc->outputs[0]].get(), "tensor_4", 0.45, 28);
  check_tensor(subgraph->tensors[fc->outputs[1]].get(), "tensor_5", 0.22, 50);
}
400 
// Run the parameterized tests above for both int8 and int16 interfaces.
INSTANTIATE_TEST_SUITE_P(MultipleInputOutputTests, ModelInterface,
                         ::testing::Values(TensorType_INT8, TensorType_INT16));
403 
TEST(ModelInterface, MixedTypeSingleInputOutput) {
  auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);

  // Request a uint8 input while keeping the int8 output.
  flatbuffers::FlatBufferBuilder builder;
  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
                                 TensorType_INT8),
            kTfLiteOk);

  // The leading quantize op is kept; only two ops remain.
  EXPECT_EQ(model->subgraphs.size(), 1);
  auto& subgraph = model->subgraphs[0];
  EXPECT_EQ(subgraph->tensors.size(), 3);
  EXPECT_EQ(subgraph->inputs.size(), 1);
  EXPECT_EQ(subgraph->inputs[0], 0);
  EXPECT_EQ(subgraph->outputs.size(), 1);
  EXPECT_EQ(subgraph->outputs[0], 2);
  EXPECT_EQ(model->operator_codes.size(), 3);
  EXPECT_EQ(subgraph->operators.size(), 2);
  EXPECT_EQ(subgraph->operators[0]->opcode_index, 0);
  EXPECT_EQ(subgraph->operators[1]->opcode_index, 1);

  // The model input is now uint8, with the zero point shifted by +128
  // (28 -> 156).
  const OperatorT* quant = subgraph->operators[0].get();
  const TensorT* in_tensor = subgraph->tensors[quant->inputs[0]].get();
  EXPECT_EQ(in_tensor->name, "tensor_0");
  EXPECT_EQ(in_tensor->type, TensorType_UINT8);
  EXPECT_FLOAT_EQ(in_tensor->quantization->scale[0], 0.35);
  EXPECT_EQ(in_tensor->quantization->zero_point[0], 156);

  // The model output stays int8 with tensor_2's original parameters.
  const OperatorT* fc = subgraph->operators[1].get();
  const TensorT* out_tensor = subgraph->tensors[fc->outputs[0]].get();
  EXPECT_EQ(out_tensor->name, "tensor_2");
  EXPECT_EQ(out_tensor->type, TensorType_INT8);
  EXPECT_FLOAT_EQ(out_tensor->quantization->scale[0], 0.12);
  EXPECT_EQ(out_tensor->quantization->zero_point[0], 50);
}
439 
TEST(ModelInterface, Uint8SingleInputOutput) {
  auto model = CreateQuantizedModelSingleInputOutput(TensorType_INT8);

  // Request uint8 for both the input and the output.
  flatbuffers::FlatBufferBuilder builder;
  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
                                 TensorType_UINT8),
            kTfLiteOk);

  // Three ops remain: a quantize on each side of the fully-connected op
  // (operators 0 and 2 both use opcode index 0, the quantize code).
  EXPECT_EQ(model->subgraphs.size(), 1);
  auto& subgraph = model->subgraphs[0];
  EXPECT_EQ(subgraph->tensors.size(), 4);
  EXPECT_EQ(subgraph->inputs.size(), 1);
  EXPECT_EQ(subgraph->inputs[0], 0);
  EXPECT_EQ(subgraph->outputs.size(), 1);
  EXPECT_EQ(subgraph->outputs[0], 3);
  EXPECT_EQ(model->operator_codes.size(), 3);
  EXPECT_EQ(subgraph->operators.size(), 3);
  EXPECT_EQ(subgraph->operators[0]->opcode_index, 0);
  EXPECT_EQ(subgraph->operators[1]->opcode_index, 1);
  EXPECT_EQ(subgraph->operators[2]->opcode_index, 0);

  // The model input is uint8 with the zero point shifted by +128 (28 -> 156).
  const OperatorT* in_quant = subgraph->operators[0].get();
  const TensorT* in_tensor = subgraph->tensors[in_quant->inputs[0]].get();
  EXPECT_EQ(in_tensor->name, "tensor_0");
  EXPECT_EQ(in_tensor->type, TensorType_UINT8);
  EXPECT_FLOAT_EQ(in_tensor->quantization->scale[0], 0.35);
  EXPECT_EQ(in_tensor->quantization->zero_point[0], 156);

  // The model output is uint8 as well (zero point 50 -> 178).
  const OperatorT* out_quant = subgraph->operators[2].get();
  const TensorT* out_tensor = subgraph->tensors[out_quant->outputs[0]].get();
  EXPECT_EQ(out_tensor->name, "tensor_3");
  EXPECT_EQ(out_tensor->type, TensorType_UINT8);
  EXPECT_FLOAT_EQ(out_tensor->quantization->scale[0], 0.12);
  EXPECT_EQ(out_tensor->quantization->zero_point[0], 178);
}
476 
TEST(ModelInterface,Uint8MutipleInputOutput)477 TEST(ModelInterface, Uint8MutipleInputOutput) {
478   auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);
479 
480   // Change model type.
481   flatbuffers::FlatBufferBuilder builder;
482   EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
483                                  TensorType_UINT8),
484             kTfLiteOk);
485 
486   // Verify results.
487   EXPECT_EQ(model->subgraphs.size(), 1);
488   EXPECT_EQ(model->subgraphs[0]->tensors.size(), 8);
489   EXPECT_EQ(model->subgraphs[0]->inputs.size(), 2);
490   EXPECT_EQ(model->subgraphs[0]->inputs[0], 0);
491   EXPECT_EQ(model->subgraphs[0]->inputs[1], 1);
492   EXPECT_EQ(model->subgraphs[0]->outputs.size(), 2);
493   EXPECT_EQ(model->subgraphs[0]->outputs[0], 6);
494   EXPECT_EQ(model->subgraphs[0]->outputs[1], 7);
495   EXPECT_EQ(model->operator_codes.size(), 3);
496   EXPECT_EQ(model->subgraphs[0]->operators.size(), 5);
497   EXPECT_EQ(model->subgraphs[0]->operators[0]->opcode_index, 0);
498   EXPECT_EQ(model->subgraphs[0]->operators[1]->opcode_index, 0);
499   EXPECT_EQ(model->subgraphs[0]->operators[2]->opcode_index, 1);
500   EXPECT_EQ(model->subgraphs[0]->operators[3]->opcode_index, 0);
501   EXPECT_EQ(model->subgraphs[0]->operators[4]->opcode_index, 0);
502 
503   auto input_quant_1 = model->subgraphs[0]->operators[0].get();
504   auto input_1 = model->subgraphs[0]->tensors[input_quant_1->inputs[0]].get();
505   EXPECT_EQ(input_1->name, "tensor_0");
506   EXPECT_EQ(input_1->type, TensorType_UINT8);
507   EXPECT_FLOAT_EQ(input_1->quantization->scale[0], 0.35);
508   EXPECT_EQ(input_1->quantization->zero_point[0], 156);
509 
510   auto input_quant_2 = model->subgraphs[0]->operators[1].get();
511   auto input_2 = model->subgraphs[0]->tensors[input_quant_2->inputs[0]].get();
512   EXPECT_EQ(input_2->name, "tensor_1");
513   EXPECT_EQ(input_2->type, TensorType_UINT8);
514   EXPECT_FLOAT_EQ(input_2->quantization->scale[0], 0.12);
515   EXPECT_EQ(input_2->quantization->zero_point[0], 178);
516 
517   auto output_quant_1 = model->subgraphs[0]->operators[3].get();
518   auto output_1 =
519       model->subgraphs[0]->tensors[output_quant_1->outputs[0]].get();
520   EXPECT_EQ(output_1->name, "tensor_6");
521   EXPECT_EQ(output_1->type, TensorType_UINT8);
522   EXPECT_FLOAT_EQ(output_1->quantization->scale[0], 0.45);
523   EXPECT_EQ(output_1->quantization->zero_point[0], 156);
524 
525   auto output_quant_2 = model->subgraphs[0]->operators[4].get();
526   auto output_2 =
527       model->subgraphs[0]->tensors[output_quant_2->outputs[0]].get();
528   EXPECT_EQ(output_2->name, "tensor_7");
529   EXPECT_EQ(output_2->type, TensorType_UINT8);
530   EXPECT_FLOAT_EQ(output_2->quantization->scale[0], 0.22);
531   EXPECT_EQ(output_2->quantization->zero_point[0], 178);
532 }
533 
TEST(ModelInterface, Int8MutipleInputOutput) {
  auto model = CreateQuantizedModelMultipleInputOutput(TensorType_INT8);

  // Request int8 interfaces, matching the model's interior type.
  flatbuffers::FlatBufferBuilder builder;
  EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_INT8,
                                 TensorType_INT8),
            kTfLiteOk);

  // Only the fully-connected op remains; the model inputs/outputs now point
  // at its quantized tensors.
  EXPECT_EQ(model->subgraphs.size(), 1);
  auto& subgraph = model->subgraphs[0];
  // TODO(mnatraj): The two float input tensors have not been removed.
  // EXPECT_EQ(subgraph->tensors.size(), 4);
  EXPECT_EQ(subgraph->tensors.size(), 6);
  EXPECT_EQ(subgraph->inputs.size(), 2);
  EXPECT_EQ(subgraph->inputs[0], 2);
  EXPECT_EQ(subgraph->inputs[1], 3);
  EXPECT_EQ(subgraph->outputs.size(), 2);
  EXPECT_EQ(subgraph->outputs[0], 4);
  EXPECT_EQ(subgraph->outputs[1], 5);
  EXPECT_EQ(model->operator_codes.size(), 3);
  EXPECT_EQ(subgraph->operators.size(), 1);
  EXPECT_EQ(subgraph->operators[0]->opcode_index, 1);

  const OperatorT* fc = subgraph->operators[0].get();

  // Each FC input/output stays int8 with its original parameters.
  auto check_int8 = [&](const TensorT* tensor, const char* name, float scale,
                        int64_t zero_point) {
    EXPECT_EQ(tensor->name, name);
    EXPECT_EQ(tensor->type, TensorType_INT8);
    EXPECT_FLOAT_EQ(tensor->quantization->scale[0], scale);
    EXPECT_EQ(tensor->quantization->zero_point[0], zero_point);
  };
  check_int8(subgraph->tensors[fc->inputs[0]].get(), "tensor_2", 0.35, 28);
  check_int8(subgraph->tensors[fc->inputs[1]].get(), "tensor_3", 0.12, 50);
  check_int8(subgraph->tensors[fc->outputs[0]].get(), "tensor_4", 0.45, 28);
  check_int8(subgraph->tensors[fc->outputs[1]].get(), "tensor_5", 0.22, 50);
}
584 
TEST(ModelInterface, Float) {
  // Build a float model and serialize it into a flatbuffer, since the API
  // under test consumes a packed Model.
  std::unique_ptr<ModelT> float_model = CreateFloatModel();
  flatbuffers::FlatBufferBuilder input_builder;
  FinishModelBuffer(input_builder,
                    Model::Pack(input_builder, float_model.get()));
  const Model* input_model = GetModel(input_builder.GetBufferPointer());

  // Attach uint8 quantize/dequantize ops at the model interface with the
  // given (scale, zero point) pairs.
  flatbuffers::FlatBufferBuilder builder;
  EXPECT_EQ(Uint8QuantizeModelInputsOutputs(&builder, input_model,
                                            {{"tensor_0", {0.4, 2}}},
                                            {{"tensor_1", {0.5, -5}}}),
            kTfLiteOk);

  // Unpack the result for inspection.
  std::unique_ptr<ModelT> model(
      GetModel(builder.GetBufferPointer())->UnPack());

  // Verify the resulting structure: dequantize and quantize op codes were
  // appended after the fully-connected code, and three ops exist.
  EXPECT_EQ(model->subgraphs.size(), 1);
  auto& subgraph = model->subgraphs[0];
  EXPECT_EQ(subgraph->tensors.size(), 4);
  EXPECT_EQ(subgraph->inputs.size(), 1);
  EXPECT_EQ(subgraph->inputs[0], 0);
  EXPECT_EQ(subgraph->outputs.size(), 1);
  EXPECT_EQ(subgraph->outputs[0], 1);
  EXPECT_EQ(model->operator_codes.size(), 3);
  EXPECT_EQ(GetBuiltinCode(model->operator_codes[0].get()),
            BuiltinOperator_FULLY_CONNECTED);
  EXPECT_EQ(GetBuiltinCode(model->operator_codes[1].get()),
            BuiltinOperator_DEQUANTIZE);
  EXPECT_EQ(GetBuiltinCode(model->operator_codes[2].get()),
            BuiltinOperator_QUANTIZE);
  EXPECT_EQ(subgraph->operators.size(), 3);

  // The new uint8 input tensor feeds the leading dequantize op.
  const OperatorT* dequant = subgraph->operators[0].get();
  const TensorT* in_tensor = subgraph->tensors[dequant->inputs[0]].get();
  EXPECT_EQ(in_tensor->name, "tensor_0_uint8");
  EXPECT_EQ(in_tensor->type, TensorType_UINT8);
  EXPECT_FLOAT_EQ(in_tensor->quantization->scale[0], 0.4);
  EXPECT_EQ(in_tensor->quantization->zero_point[0], 2);

  // The new uint8 output tensor is produced by the trailing quantize op.
  const OperatorT* quant = subgraph->operators[2].get();
  const TensorT* out_tensor = subgraph->tensors[quant->outputs[0]].get();
  EXPECT_EQ(out_tensor->name, "tensor_1_uint8");
  EXPECT_EQ(out_tensor->type, TensorType_UINT8);
  EXPECT_FLOAT_EQ(out_tensor->quantization->scale[0], 0.5);
  EXPECT_EQ(out_tensor->quantization->zero_point[0], -5);
}
637 
638 }  // namespace
639 }  // namespace optimize
640 }  // namespace tflite
641 
main(int argc,char ** argv)642 int main(int argc, char** argv) {
643   ::testing::InitGoogleTest(&argc, argv);
644   return RUN_ALL_TESTS();
645 }
646