1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <vector>
17
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
21 #include "tensorflow/lite/kernels/internal/types.h"
22 #include "tensorflow/lite/kernels/test_util.h"
23
24 namespace tflite {
25 namespace {
26
27 using ::testing::ElementsAreArray;
28
// Per-channel dequantization of a 2-D int8 tensor quantized along dimension 0,
// so each row has its own scale and zero point. Expected values follow
// real = scale[c] * (q - zero_point[c]), e.g. 0.5 * (-128 - (-1)) = -63.5.
TEST(PerChannelDequantize, TestInt8ToFloat_2D) {
  const std::vector<float> scales = {0.5, 0.25};
  // int32_t rather than int: the params' zero-point pointer is expected to be
  // `const int32_t*` (PerChannelDequantizationParams -- confirm if it changes),
  // and int32_t is not `int` on every toolchain.
  const std::vector<int32_t> zero_points = {-1, -1};
  const int quantized_dimension = 0;

  const RuntimeShape shape({2, 5});

  // Row 0 (channel 0: scale 0.5, zero point -1) and
  // row 1 (channel 1: scale 0.25, zero point -1).
  const std::vector<int8_t> input = {-128, -127, -126, -125, -124,
                                     123,  124,  125,  126,  127};
  // Sentinel fill; -1 is not among the expected values, so any element the
  // kernel leaves unwritten shows up as a mismatch below.
  std::vector<float> output(10, -1);

  PerChannelDequantizationParams op_params;
  op_params.zero_point = zero_points.data();
  op_params.scale = scales.data();
  op_params.quantized_dimension = quantized_dimension;
  reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
                                      output.data());
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear({-63.5, -63, -62.5, -62, -61.5,
                                               31, 31.25, 31.5, 31.75, 32})));
}
50
// Per-channel dequantization of a 3-D int8 tensor quantized along its
// innermost axis (dimension 2, size 5). The channel of flat element i is
// therefore i % 5, and each element maps to scale[c] * (q - zero_point[c]).
TEST(PerChannelDequantize, TestInt8ToFloat_3D) {
  const std::vector<float> scales = {0.5, 0.25, 0.5, 0.25, 1.0};
  // int32_t rather than int: the params' zero-point pointer is expected to be
  // `const int32_t*` (PerChannelDequantizationParams -- confirm if it changes),
  // and int32_t is not `int` on every toolchain.
  const std::vector<int32_t> zero_points = {-1, 1, -1, 1, 0};
  const int quantized_dimension = 2;

  const RuntimeShape shape({1, 2, 5});

  const std::vector<int8_t> input = {-128, -127, -126, -125, -124,
                                     123,  124,  125,  126,  127};
  // Sentinel fill; -1 is not among the expected values, so any element the
  // kernel leaves unwritten shows up as a mismatch below.
  std::vector<float> output(10, -1);

  PerChannelDequantizationParams op_params;
  op_params.zero_point = zero_points.data();
  op_params.scale = scales.data();
  op_params.quantized_dimension = quantized_dimension;
  reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
                                      output.data());
  // Both rows cycle through the same five channels; channel 4 (scale 1.0,
  // zero point 0) passes values through unchanged (-124 and 127).
  EXPECT_THAT(output,
              ElementsAreArray(ArrayFloatNear({-63.5, -32, -62.5, -31.5, -124,
                                               62, 30.75, 63, 31.25, 127})));
}
72
// Per-channel dequantization of a 4-D int8 tensor quantized along its
// outermost axis (dimension 0, size 2). With shape {2, 2, 5, 1}, the first 10
// flat elements belong to channel 0 and the last 10 to channel 1; each maps to
// scale[c] * (q - zero_point[c]).
TEST(PerChannelDequantize, TestInt8ToFloat_4DDim0) {
  const std::vector<float> scales = {0.5, 0.25};
  // int32_t rather than int: the params' zero-point pointer is expected to be
  // `const int32_t*` (PerChannelDequantizationParams -- confirm if it changes),
  // and int32_t is not `int` on every toolchain.
  const std::vector<int32_t> zero_points = {-1, 1};
  const int quantized_dimension = 0;

  const RuntimeShape shape({2, 2, 5, 1});

  // Both channel halves use the same raw values so the expected outputs differ
  // only by each channel's scale and zero point.
  const std::vector<int8_t> input = {-128, -127, -126, -125, -124, 123, 124,
                                     125,  126,  127,  -128, -127, -126, -125,
                                     -124, 123,  124,  125,  126,  127};
  // Sentinel fill; -1 is not among the expected values, so any element the
  // kernel leaves unwritten shows up as a mismatch below.
  std::vector<float> output(20, -1);

  PerChannelDequantizationParams op_params;
  op_params.zero_point = zero_points.data();
  op_params.scale = scales.data();
  op_params.quantized_dimension = quantized_dimension;
  reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
                                      output.data());
  // First 10: channel 0 (scale 0.5, zero point -1).
  // Last 10: channel 1 (scale 0.25, zero point 1).
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
                          {-63.5, -63, -62.5, -62, -61.5, 62, 62.5,
                           63, 63.5, 64, -32.25, -32, -31.75, -31.5,
                           -31.25, 30.5, 30.75, 31, 31.25, 31.5})));
}
96
// Per-channel dequantization of a 4-D int8 tensor quantized along its
// innermost axis (dimension 3, size 5). With shape {1, 2, 2, 5}, the channel of
// flat element i is i % 5; each maps to scale[c] * (q - zero_point[c]).
TEST(PerChannelDequantize, TestInt8ToFloat_4DDim3) {
  const std::vector<float> scales = {0.5, 0.25, 0.5, 0.25, 1.0};
  // int32_t rather than int: the params' zero-point pointer is expected to be
  // `const int32_t*` (PerChannelDequantizationParams -- confirm if it changes),
  // and int32_t is not `int` on every toolchain.
  const std::vector<int32_t> zero_points = {-1, 1, -1, 1, 0};
  const int quantized_dimension = 3;

  const RuntimeShape shape({1, 2, 2, 5});

  // Two identical 10-element halves, so the expected output repeats as well.
  const std::vector<int8_t> input = {-128, -127, -126, -125, -124, 123, 124,
                                     125,  126,  127,  -128, -127, -126, -125,
                                     -124, 123,  124,  125,  126,  127};
  // Sentinel fill; -1 is not among the expected values, so any element the
  // kernel leaves unwritten shows up as a mismatch below.
  std::vector<float> output(20, -1);

  PerChannelDequantizationParams op_params;
  op_params.zero_point = zero_points.data();
  op_params.scale = scales.data();
  op_params.quantized_dimension = quantized_dimension;
  reference_ops::PerChannelDequantize(op_params, shape, input.data(), shape,
                                      output.data());
  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
                          {-63.5, -32, -62.5, -31.5, -124, 62, 30.75,
                           63, 31.25, 127, -63.5, -32, -62.5, -31.5,
                           -124, 62, 30.75, 63, 31.25, 127})));
}
120
121 } // namespace
122 } // namespace tflite
123