• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <cstdint>
16 #include <vector>
17 
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
21 #include "tensorflow/lite/kernels/internal/types.h"
22 #include "tensorflow/lite/kernels/test_util.h"
23 
24 namespace tflite {
25 namespace {
26 
27 using ::testing::ElementsAreArray;
28 
TEST(PerChannelDequantize,TestInt8ToFloat_2D)29 TEST(PerChannelDequantize, TestInt8ToFloat_2D) {
30   const std::vector<float> scales = {0.5, 0.25};
31   const std::vector<int> zero_points = {-1, -1};
32   const int quantized_dimension = 0;
33 
34   const RuntimeShape shape({2, 5});
35 
36   const std::vector<int8_t> input = {-128, -127, -126, -125, -124,
37                                      123,  124,  125,  126,  127};
38   std::vector<float> output(10, -1);
39 
40   PerChannelDequantizationParams op_params;
41   op_params.zero_point = zero_points.data();
42   op_params.scale = scales.data();
43   op_params.quantized_dimension = quantized_dimension;
44   reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
45                                       output.data());
46   EXPECT_THAT(output,
47               ElementsAreArray(ArrayFloatNear({-63.5, -63, -62.5, -62, -61.5,
48                                                31, 31.25, 31.5, 31.75, 32})));
49 }
50 
TEST(PerChannelDequantize,TestInt8ToFloat_3D)51 TEST(PerChannelDequantize, TestInt8ToFloat_3D) {
52   const std::vector<float> scales = {0.5, 0.25, 0.5, 0.25, 1.0};
53   const std::vector<int> zero_points = {-1, 1, -1, 1, 0};
54   const int quantized_dimension = 2;
55 
56   const RuntimeShape shape({1, 2, 5});
57 
58   const std::vector<int8_t> input = {-128, -127, -126, -125, -124,
59                                      123,  124,  125,  126,  127};
60   std::vector<float> output(10, -1);
61 
62   PerChannelDequantizationParams op_params;
63   op_params.zero_point = zero_points.data();
64   op_params.scale = scales.data();
65   op_params.quantized_dimension = quantized_dimension;
66   reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
67                                       output.data());
68   EXPECT_THAT(output,
69               ElementsAreArray(ArrayFloatNear({-63.5, -32, -62.5, -31.5, -124,
70                                                62, 30.75, 63, 31.25, 127})));
71 }
72 
TEST(PerChannelDequantize,TestInt8ToFloat_4DDim0)73 TEST(PerChannelDequantize, TestInt8ToFloat_4DDim0) {
74   const std::vector<float> scales = {0.5, 0.25};
75   const std::vector<int> zero_points = {-1, 1};
76   const int quantized_dimension = 0;
77 
78   RuntimeShape shape({2, 2, 5, 1});
79 
80   const std::vector<int8_t> input = {-128, -127, -126, -125, -124, 123,  124,
81                                      125,  126,  127,  -128, -127, -126, -125,
82                                      -124, 123,  124,  125,  126,  127};
83   std::vector<float> output(20, -1);
84 
85   PerChannelDequantizationParams op_params;
86   op_params.zero_point = zero_points.data();
87   op_params.scale = scales.data();
88   op_params.quantized_dimension = quantized_dimension;
89   reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
90                                       output.data());
91   EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
92                           {-63.5,  -63,  -62.5, -62,    -61.5, 62,     62.5,
93                            63,     63.5, 64,    -32.25, -32,   -31.75, -31.5,
94                            -31.25, 30.5, 30.75, 31,     31.25, 31.5})));
95 }
96 
TEST(PerChannelDequantize,TestInt8ToFloat_4DDim3)97 TEST(PerChannelDequantize, TestInt8ToFloat_4DDim3) {
98   const std::vector<float> scales = {0.5, 0.25, 0.5, 0.25, 1.0};
99   const std::vector<int> zero_points = {-1, 1, -1, 1, 0};
100   const int quantized_dimension = 3;
101 
102   RuntimeShape shape({1, 2, 2, 5});
103 
104   const std::vector<int8_t> input = {-128, -127, -126, -125, -124, 123,  124,
105                                      125,  126,  127,  -128, -127, -126, -125,
106                                      -124, 123,  124,  125,  126,  127};
107   std::vector<float> output(20, -1);
108 
109   PerChannelDequantizationParams op_params;
110   op_params.zero_point = zero_points.data();
111   op_params.scale = scales.data();
112   op_params.quantized_dimension = quantized_dimension;
113   reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
114                                       output.data());
115   EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
116                           {-63.5, -32,   -62.5, -31.5, -124,  62,    30.75,
117                            63,    31.25, 127,   -63.5, -32,   -62.5, -31.5,
118                            -124,  62,    30.75, 63,    31.25, 127})));
119 }
120 
121 }  // namespace
122 }  // namespace tflite
123