• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <vector>
25 
26 #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
27 
28 #include <gmock/gmock.h>
29 #include <gtest/gtest.h>
30 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
31 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
32 #include "tensorflow/lite/string_type.h"
33 
34 namespace tflite {
35 
// Deterministic source of pseudo-random integers for the tests below.
//
// Draws from a default-seeded std::mt19937 and scales the raw output by hand
// (rather than via std::uniform_int_distribution, whose algorithm is
// implementation-defined), so the generated test inputs do not depend on the
// standard library in use.
class NumberGenerator {
 public:
  // Returns `n` integers drawn approximately uniformly from the closed range
  // [min_val, max_val]. Requires n >= 0 and min_val <= max_val; any such pair
  // of ints is supported, including the full int range.
  std::vector<int> RandomIntVector(int n, int min_val, int max_val) {
    std::vector<int> vec(n);
    // Width of the range, computed in double: `max_val + 1 - min_val` in int
    // overflows when the range holds more than INT_MAX values.
    const double range = static_cast<double>(max_val) - min_val + 1.0;
    // Divide by (max() + 1), not max(), so engine_() * scale stays strictly
    // below `range`; otherwise a draw of exactly engine_.max() would yield
    // max_val + 1.
    const double scale =
        range / (static_cast<double>(engine_.max()) + 1.0);
    for (auto& it : vec) {
      it = static_cast<int>(min_val + std::floor(engine_() * scale));
    }
    return vec;
  }

  std::mt19937 engine_;
};
49 
50 class LogQuantizedTest : public ::testing::Test {
51  public:
52   NumberGenerator generator_;
53 };
54 
55 // input_integer_bits <= 30.  output_integer_bits > 0.
LogPositiveValuesViaFloat(int32 input_val,int input_integer_bits,int output_integer_bits)56 inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits,
57                                        int output_integer_bits) {
58   const double float_log_sum_of_exps = std::log(
59       static_cast<double>(input_val) * 0.5 / (1 << (30 - input_integer_bits)));
60   static constexpr double min_int =
61       static_cast<double>(std::numeric_limits<int32>::min());
62   static constexpr double max_int =
63       static_cast<double>(std::numeric_limits<int32>::max());
64   double double_result = tflite::TfLiteRound(float_log_sum_of_exps *
65                                              (1 << (31 - output_integer_bits)));
66   return static_cast<std::int32_t>(
67       std::min(max_int, std::max(min_int, double_result)));
68 }
69 
CheckOutputData(const std::vector<int32> & test_output,const std::vector<int32> & reference_output,const std::vector<int32> & test_input,const string & check_label,int input_integer_bits,int output_integer_bits,int tolerance)70 void CheckOutputData(const std::vector<int32>& test_output,
71                      const std::vector<int32>& reference_output,
72                      const std::vector<int32>& test_input,
73                      const string& check_label, int input_integer_bits,
74                      int output_integer_bits, int tolerance) {
75   // In the special case of small input, specifically raw value of 5, a rounding
76   // up leads to difference in the output.  We do not aim to be accurate for
77   // very small input values, and there should be sufficient input fractional
78   // bits that this is a small input.
79   static constexpr double error_from_rounding_up = 0.0224585;
80   const int n = test_output.size();
81   ASSERT_EQ(n, reference_output.size());
82   for (int i = 0; i < n; ++i) {
83     // Adjust tolerance when input <= 5*2^-(31-input_integer_bits).
84     const int adjusted_tolerance =
85         test_input[i] > 5
86             ? tolerance
87             : std::max(tolerance, static_cast<int>(std::ceil(
88                                       error_from_rounding_up *
89                                       (1 << (31 - output_integer_bits)))));
90     ASSERT_LE(std::abs(test_output[i] - reference_output[i]),
91               adjusted_tolerance)
92         << "Failure in \"" << check_label << "\" at i=" << i
93         << ", test_input[i]=" << test_input[i] << "="
94         << static_cast<double>(test_input[i]) / (1 << (31 - input_integer_bits))
95         << ", test_output[i]=" << test_output[i] << "="
96         << static_cast<double>(test_output[i]) /
97                (1 << (31 - output_integer_bits))
98         << ", reference_output[i]=" << reference_output[i] << "="
99         << static_cast<double>(reference_output[i]) /
100                (1 << (31 - output_integer_bits))
101         << ", difference[i]=" << std::abs(reference_output[i] - test_output[i])
102         << "="
103         << static_cast<double>(std::abs(reference_output[i] - test_output[i])) /
104                (1 << (31 - output_integer_bits))
105         << "; tolerance=" << tolerance
106         << ", adj tolerance=" << adjusted_tolerance;
107   }
108 }
109 
RightShiftVector(const std::vector<int32> & shifts,std::vector<int32> * vec)110 void RightShiftVector(const std::vector<int32>& shifts,
111                       std::vector<int32>* vec) {
112   const int n = vec->size();
113   ASSERT_EQ(n, shifts.size());
114   for (int i = 0; i < n; ++i) {
115     vec->at(i) = std::max(1, vec->at(i) >> shifts[i]);
116   }
117 }
118 
119 template <int OutputIntegerBits, int InputIntegerBits>
RunSingleTest(const std::vector<int32> & test_input,const string & check_label,int tolerance)120 void RunSingleTest(const std::vector<int32>& test_input,
121                    const string& check_label, int tolerance) {
122   const int n = test_input.size();
123   std::vector<int32> float_gen_output(n, 0);
124   std::vector<int32> quantized_output(n, 0);
125 
126   // Workaround the stupid things that intelligent humans do.
127   // Consequence of __builtin_clz(0u) may equal 31 instead of 32.
128   std::vector<int32> fudged_input(n, 0);
129   for (int i = 0; i < n; ++i) {
130     fudged_input[i] = std::max(test_input[i], 2);
131   }
132 
133   for (int i = 0; i < n; ++i) {
134     quantized_output[i] =
135         tflite::log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
136                                                             InputIntegerBits>(
137             gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
138                 fudged_input[i]))
139             .raw();
140     float_gen_output[i] = LogPositiveValuesViaFloat(
141         fudged_input[i], InputIntegerBits, OutputIntegerBits);
142   }
143   {
144     std::ostringstream label;
145     label << check_label
146           << " / reference vs float-gen / InputIntegerBits=" << InputIntegerBits
147           << ", OutputIntegerBits=" << OutputIntegerBits;
148     CheckOutputData(quantized_output, float_gen_output, test_input, label.str(),
149                     InputIntegerBits, OutputIntegerBits, tolerance);
150   }
151 }
152 
153 template <int OutputIntegerBits>
RunSingleTest(const std::vector<int32> & test_input,int input_integer_bits,const string & check_label,int tolerance)154 void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
155                    const string& check_label, int tolerance) {
156 #define INPUT_CASE(K)                                                   \
157   case K:                                                               \
158     return RunSingleTest<OutputIntegerBits, K>(test_input, check_label, \
159                                                tolerance)
160   switch (input_integer_bits) {
161     INPUT_CASE(0);
162     INPUT_CASE(1);
163     INPUT_CASE(2);
164     INPUT_CASE(3);
165     INPUT_CASE(4);
166     INPUT_CASE(5);
167     INPUT_CASE(6);
168     INPUT_CASE(7);
169     INPUT_CASE(8);
170     INPUT_CASE(9);
171     INPUT_CASE(10);
172     INPUT_CASE(11);
173     INPUT_CASE(12);
174     INPUT_CASE(13);
175     INPUT_CASE(14);
176     INPUT_CASE(15);
177     INPUT_CASE(16);
178     INPUT_CASE(17);
179     INPUT_CASE(18);
180     INPUT_CASE(19);
181     INPUT_CASE(20);
182     INPUT_CASE(21);
183     INPUT_CASE(22);
184     INPUT_CASE(23);
185     INPUT_CASE(24);
186     INPUT_CASE(25);
187     INPUT_CASE(26);
188     INPUT_CASE(27);
189     INPUT_CASE(28);
190     INPUT_CASE(29);
191     default:
192       ASSERT_LE(input_integer_bits, 30)
193           << "Input integer bits not handled: " << input_integer_bits;
194   }
195 #undef INPUT_CASE
196 }
197 
RunSingleTest(const std::vector<int32> & test_input,int input_integer_bits,int output_integer_bits,const string & check_label,int tolerance)198 void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
199                    int output_integer_bits, const string& check_label,
200                    int tolerance) {
201 #define OUTPUT_CASE(K)                                                   \
202   case K:                                                                \
203     return RunSingleTest<K>(test_input, input_integer_bits, check_label, \
204                             tolerance)
205   switch (output_integer_bits) {
206     OUTPUT_CASE(0);
207     OUTPUT_CASE(1);
208     OUTPUT_CASE(2);
209     OUTPUT_CASE(3);
210     OUTPUT_CASE(4);
211     OUTPUT_CASE(5);
212     OUTPUT_CASE(6);
213     OUTPUT_CASE(7);
214     OUTPUT_CASE(8);
215     OUTPUT_CASE(9);
216     OUTPUT_CASE(10);
217     OUTPUT_CASE(11);
218     OUTPUT_CASE(12);
219     OUTPUT_CASE(13);
220     OUTPUT_CASE(14);
221     OUTPUT_CASE(15);
222     OUTPUT_CASE(16);
223     OUTPUT_CASE(17);
224     OUTPUT_CASE(18);
225     OUTPUT_CASE(19);
226     OUTPUT_CASE(20);
227     OUTPUT_CASE(21);
228     OUTPUT_CASE(22);
229     OUTPUT_CASE(23);
230     OUTPUT_CASE(24);
231     OUTPUT_CASE(25);
232     OUTPUT_CASE(26);
233     OUTPUT_CASE(27);
234     OUTPUT_CASE(28);
235     OUTPUT_CASE(29);
236     default:
237       ASSERT_LE(input_integer_bits, 30)
238           << "Input integer bits not handled: " << input_integer_bits;
239   }
240 #undef OUTPUT_CASE
241 }
242 
RunUniformTest(int test_size,int input_integer_bits,int output_integer_bits,const string & check_label,int tolerance,NumberGenerator * generator)243 void RunUniformTest(int test_size, int input_integer_bits,
244                     int output_integer_bits, const string& check_label,
245                     int tolerance, NumberGenerator* generator) {
246   std::vector<int> test_data = generator->RandomIntVector(
247       test_size, 2, std::numeric_limits<int>::max() - 1);
248   test_data[0] = 2;
249   test_data[1] = 3;
250   test_data[2] = 4;
251   test_data[3] = std::numeric_limits<int32>::max() - 2;
252   test_data[4] = std::numeric_limits<int32>::max() - 1;
253   test_data[5] = std::numeric_limits<int32>::max();
254 
255   RunSingleTest(test_data, input_integer_bits, output_integer_bits,
256                 check_label + " / uniform test", tolerance);
257 }
258 
RunUniformShiftUniformTest(int test_size,int input_integer_bits,int output_integer_bits,const string & check_label,int tolerance,NumberGenerator * generator)259 void RunUniformShiftUniformTest(int test_size, int input_integer_bits,
260                                 int output_integer_bits,
261                                 const string& check_label, int tolerance,
262                                 NumberGenerator* generator) {
263   std::vector<int> test_data = generator->RandomIntVector(
264       test_size, 2, std::numeric_limits<int>::max() - 1);
265   std::vector<int> shifts = generator->RandomIntVector(test_size, 0, 29);
266   RightShiftVector(shifts, &test_data);
267 
268   RunSingleTest(test_data, input_integer_bits, output_integer_bits,
269                 check_label + " / shifted test", tolerance);
270 }
271 
// Sweeps randomly chosen (input, output) integer-bit combinations. Each
// combination is run with both the uniform and the shifted-uniform input
// distributions.
TEST_F(LogQuantizedTest, VariedIntegerBits) {
  static constexpr int kVariations = 250;
  static constexpr int kRunSize = 250;
  static constexpr int kIntegerTolerance = 8;
  static constexpr double kOutputFloatTolerance = 7.0e-7;

  const std::vector<int> input_integer_bits =
      generator_.RandomIntVector(kVariations, 0, 24);
  const std::vector<int> output_integer_bits =
      generator_.RandomIntVector(kVariations, 1, 10);

  for (int variation = 0; variation < kVariations; ++variation) {
    const int in_bits = input_integer_bits[variation];
    const int out_bits = output_integer_bits[variation];
    // Tolerance in raw output units: the larger of a fixed raw floor and
    // kOutputFloatTolerance converted to raw units, truncated toward zero.
    const double float_tolerance_in_raw_units =
        (1 << (31 - out_bits)) * kOutputFloatTolerance;
    const int tolerance = static_cast<int>(std::max(
        static_cast<double>(kIntegerTolerance), float_tolerance_in_raw_units));

    RunUniformTest(kRunSize, in_bits, out_bits, "VariedIntegerBits",
                   tolerance, &generator_);
    RunUniformShiftUniformTest(kRunSize, in_bits, out_bits,
                               "VariedIntegerBits", tolerance, &generator_);
  }
}
296 
// Exercises a single fixed combination (12 input, 5 output integer bits). It
// uses a large number of samples and a raw-unit tolerance of
// kIntegerTolerance, under both input distributions.
TEST_F(LogQuantizedTest, SelectedIntegerBits) {
  static constexpr int kInputBits = 12;
  static constexpr int kOutputBits = 5;
  static constexpr int kRunSize = 100000;
  static constexpr int kIntegerTolerance = 4;
  const string label = "SelectedIntegerBits";

  RunUniformTest(kRunSize, kInputBits, kOutputBits, label, kIntegerTolerance,
                 &generator_);
  RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits, label,
                             kIntegerTolerance, &generator_);
}
309 
310 }  // namespace tflite
311