• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <cmath>
17 #include <cstdlib>
18 #include <functional>
19 #include <iterator>
20 #include <limits>
21 #include <random>
22 #include <sstream>
23 #include <string>
24 #include <vector>
25 
26 #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
27 
28 #include <gmock/gmock.h>
29 #include <gtest/gtest.h>
30 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
31 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
32 #include "tensorflow/lite/string.h"
33 
34 namespace tflite {
35 
// Deterministic source of pseudo-random test data.
//
// Wraps a default-constructed std::mt19937, so the sequence of values is the
// same on every run unless a caller reseeds `engine_`.
class NumberGenerator {
 public:
  // Returns `n` pseudo-random integers in the closed range
  // [min_val, max_val].  Requires min_val <= max_val.
  //
  // Each raw engine sample is scaled onto the requested span and floored.
  std::vector<int> RandomIntVector(int n, int min_val, int max_val) {
    // The span is computed in double so that max_val == INT_MAX (or a span
    // wider than int) cannot overflow.
    const double span =
        static_cast<double>(max_val) + 1.0 - static_cast<double>(min_val);
    // Dividing by max() + 1 maps samples onto the half-open range [0, span),
    // so even the largest engine sample floors to max_val, never max_val + 1.
    const double scale = span / (static_cast<double>(engine_.max()) + 1.0);

    std::vector<int> vec(n);
    for (auto& it : vec) {
      it = static_cast<int>(min_val + std::floor(engine_() * scale));
    }
    return vec;
  }

  std::mt19937 engine_;
};
49 
50 class LogQuantizedTest : public ::testing::Test {
51  public:
52   NumberGenerator generator_;
53 };
54 
55 // input_integer_bits <= 30.  output_integer_bits > 0.
LogPositiveValuesViaFloat(int32 input_val,int input_integer_bits,int output_integer_bits)56 inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits,
57                                        int output_integer_bits) {
58   const double float_log_sum_of_exps = std::log(
59       static_cast<double>(input_val) * 0.5 / (1 << (30 - input_integer_bits)));
60   static constexpr double min_int =
61       static_cast<double>(std::numeric_limits<int32>::min());
62   static constexpr double max_int =
63       static_cast<double>(std::numeric_limits<int32>::max());
64   double double_result = tflite::TfLiteRound(float_log_sum_of_exps *
65                                              (1 << (31 - output_integer_bits)));
66   return static_cast<std::int32_t>(
67       std::min(max_int, std::max(min_int, double_result)));
68 }
69 
CheckOutputData(const std::vector<int32> & test_output,const std::vector<int32> & reference_output,const std::vector<int32> & test_input,const string & check_label,int input_integer_bits,int output_integer_bits,int tolerance)70 void CheckOutputData(const std::vector<int32>& test_output,
71                      const std::vector<int32>& reference_output,
72                      const std::vector<int32>& test_input,
73                      const string& check_label, int input_integer_bits,
74                      int output_integer_bits, int tolerance) {
75   // In the special case of small input, specifically raw value of 5, a rounding
76   // up leads to difference in the output.  We do not aim to be accurate for
77   // very small input values, and there should be sufficient input fractional
78   // bits that this is a small input.
79   static constexpr double error_from_rounding_up = 0.0224585;
80   const int n = test_output.size();
81   ASSERT_EQ(n, reference_output.size());
82   for (int i = 0; i < n; ++i) {
83     // Adjust tolerance when input <= 5*2^-(31-input_integer_bits).
84     const int adjusted_tolerance =
85         test_input[i] > 5
86             ? tolerance
87             : std::max(tolerance, static_cast<int>(std::ceil(
88                                       error_from_rounding_up *
89                                       (1 << (31 - output_integer_bits)))));
90     ASSERT_LE(std::abs(test_output[i] - reference_output[i]),
91               adjusted_tolerance)
92         << "Failure in \"" << check_label << "\" at i=" << i
93         << ", test_input[i]=" << test_input[i] << "="
94         << static_cast<double>(test_input[i]) / (1 << (31 - input_integer_bits))
95         << ", test_output[i]=" << test_output[i] << "="
96         << static_cast<double>(test_output[i]) /
97                (1 << (31 - output_integer_bits))
98         << ", reference_output[i]=" << reference_output[i] << "="
99         << static_cast<double>(reference_output[i]) /
100                (1 << (31 - output_integer_bits))
101         << ", difference[i]=" << std::abs(reference_output[i] - test_output[i])
102         << "="
103         << static_cast<double>(std::abs(reference_output[i] - test_output[i])) /
104                (1 << (31 - output_integer_bits))
105         << "; tolerance=" << tolerance
106         << ", adj tolerance=" << adjusted_tolerance;
107   }
108 }
109 
RightShiftVector(const std::vector<int32> & shifts,std::vector<int32> * vec)110 void RightShiftVector(const std::vector<int32>& shifts,
111                       std::vector<int32>* vec) {
112   const int n = vec->size();
113   ASSERT_EQ(n, shifts.size());
114   for (int i = 0; i < n; ++i) {
115     vec->at(i) = std::max(1, vec->at(i) >> shifts[i]);
116   }
117 }
118 
119 template <int OutputIntegerBits, int InputIntegerBits>
RunSingleTest(const std::vector<int32> & test_input,const string & check_label,int tolerance)120 void RunSingleTest(const std::vector<int32>& test_input,
121                    const string& check_label, int tolerance) {
122   const int n = test_input.size();
123   std::vector<int32> float_gen_output(n, 0);
124   std::vector<int32> quantized_output(n, 0);
125 
126   // Workaround the stupid things that intelligent humans do.
127   // Consequence of __builtin_clz(0u) may equal 31 instead of 32.
128   std::vector<int32> fudged_input(n, 0);
129   for (int i = 0; i < n; ++i) {
130     fudged_input[i] = std::max(test_input[i], 2);
131   }
132 
133   for (int i = 0; i < n; ++i) {
134     quantized_output[i] =
135         tflite::log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
136                                                             InputIntegerBits>(
137             gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
138                 fudged_input[i]))
139             .raw();
140     float_gen_output[i] = LogPositiveValuesViaFloat(
141         fudged_input[i], InputIntegerBits, OutputIntegerBits);
142   }
143   {
144     std::ostringstream label;
145     label << check_label << " / reference vs float-gen / InputIntegerBits="
146           << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
147     CheckOutputData(quantized_output, float_gen_output, test_input, label.str(),
148                     InputIntegerBits, OutputIntegerBits, tolerance);
149   }
150 }
151 
152 template <int OutputIntegerBits>
RunSingleTest(const std::vector<int32> & test_input,int input_integer_bits,const string & check_label,int tolerance)153 void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
154                    const string& check_label, int tolerance) {
155 #define INPUT_CASE(K)                                                   \
156   case K:                                                               \
157     return RunSingleTest<OutputIntegerBits, K>(test_input, check_label, \
158                                                tolerance)
159   switch (input_integer_bits) {
160     INPUT_CASE(0);
161     INPUT_CASE(1);
162     INPUT_CASE(2);
163     INPUT_CASE(3);
164     INPUT_CASE(4);
165     INPUT_CASE(5);
166     INPUT_CASE(6);
167     INPUT_CASE(7);
168     INPUT_CASE(8);
169     INPUT_CASE(9);
170     INPUT_CASE(10);
171     INPUT_CASE(11);
172     INPUT_CASE(12);
173     INPUT_CASE(13);
174     INPUT_CASE(14);
175     INPUT_CASE(15);
176     INPUT_CASE(16);
177     INPUT_CASE(17);
178     INPUT_CASE(18);
179     INPUT_CASE(19);
180     INPUT_CASE(20);
181     INPUT_CASE(21);
182     INPUT_CASE(22);
183     INPUT_CASE(23);
184     INPUT_CASE(24);
185     INPUT_CASE(25);
186     INPUT_CASE(26);
187     INPUT_CASE(27);
188     INPUT_CASE(28);
189     INPUT_CASE(29);
190     default:
191       ASSERT_LE(input_integer_bits, 30)
192                 << "Input integer bits not handled: " << input_integer_bits;
193   }
194 #undef INPUT_CASE
195 }
196 
RunSingleTest(const std::vector<int32> & test_input,int input_integer_bits,int output_integer_bits,const string & check_label,int tolerance)197 void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
198                    int output_integer_bits, const string& check_label,
199                    int tolerance) {
200 #define OUTPUT_CASE(K)                                                   \
201   case K:                                                                \
202     return RunSingleTest<K>(test_input, input_integer_bits, check_label, \
203                             tolerance)
204   switch (output_integer_bits) {
205     OUTPUT_CASE(0);
206     OUTPUT_CASE(1);
207     OUTPUT_CASE(2);
208     OUTPUT_CASE(3);
209     OUTPUT_CASE(4);
210     OUTPUT_CASE(5);
211     OUTPUT_CASE(6);
212     OUTPUT_CASE(7);
213     OUTPUT_CASE(8);
214     OUTPUT_CASE(9);
215     OUTPUT_CASE(10);
216     OUTPUT_CASE(11);
217     OUTPUT_CASE(12);
218     OUTPUT_CASE(13);
219     OUTPUT_CASE(14);
220     OUTPUT_CASE(15);
221     OUTPUT_CASE(16);
222     OUTPUT_CASE(17);
223     OUTPUT_CASE(18);
224     OUTPUT_CASE(19);
225     OUTPUT_CASE(20);
226     OUTPUT_CASE(21);
227     OUTPUT_CASE(22);
228     OUTPUT_CASE(23);
229     OUTPUT_CASE(24);
230     OUTPUT_CASE(25);
231     OUTPUT_CASE(26);
232     OUTPUT_CASE(27);
233     OUTPUT_CASE(28);
234     OUTPUT_CASE(29);
235     default:
236       ASSERT_LE(input_integer_bits, 30)
237                 << "Input integer bits not handled: " << input_integer_bits;
238   }
239 #undef OUTPUT_CASE
240 }
241 
RunUniformTest(int test_size,int input_integer_bits,int output_integer_bits,const string & check_label,int tolerance,NumberGenerator * generator)242 void RunUniformTest(int test_size, int input_integer_bits,
243                     int output_integer_bits, const string& check_label,
244                     int tolerance, NumberGenerator* generator) {
245   std::vector<int> test_data = generator->RandomIntVector(
246       test_size, 2, std::numeric_limits<int>::max() - 1);
247   test_data[0] = 2;
248   test_data[1] = 3;
249   test_data[2] = 4;
250   test_data[3] = std::numeric_limits<int32>::max() - 2;
251   test_data[4] = std::numeric_limits<int32>::max() - 1;
252   test_data[5] = std::numeric_limits<int32>::max();
253 
254   RunSingleTest(test_data, input_integer_bits, output_integer_bits,
255                 check_label + " / uniform test", tolerance);
256 }
257 
RunUniformShiftUniformTest(int test_size,int input_integer_bits,int output_integer_bits,const string & check_label,int tolerance,NumberGenerator * generator)258 void RunUniformShiftUniformTest(int test_size, int input_integer_bits,
259                                 int output_integer_bits,
260                                 const string& check_label, int tolerance,
261                                 NumberGenerator* generator) {
262   std::vector<int> test_data = generator->RandomIntVector(
263       test_size, 2, std::numeric_limits<int>::max() - 1);
264   std::vector<int> shifts = generator->RandomIntVector(test_size, 0, 29);
265   RightShiftVector(shifts, &test_data);
266 
267   RunSingleTest(test_data, input_integer_bits, output_integer_bits,
268                 check_label + " / shifted test", tolerance);
269 }
270 
// Sweeps 250 pseudo-randomly chosen configurations (input integer bits in
// [0, 24], output integer bits in [1, 10]) and runs both the uniform and the
// shifted-uniform comparison for each of them.
TEST_F(LogQuantizedTest, VariedIntegerBits) {
  constexpr int kVariations = 250;
  constexpr int kRunSize = 250;
  constexpr int kIntegerTolerance = 8;
  constexpr double kOutputFloatTolerance = 7.0e-7;

  const std::vector<int> input_bits_choices =
      generator_.RandomIntVector(kVariations, 0, 24);
  const std::vector<int> output_bits_choices =
      generator_.RandomIntVector(kVariations, 1, 10);

  for (int variation = 0; variation < kVariations; ++variation) {
    const int input_bits = input_bits_choices[variation];
    const int output_bits = output_bits_choices[variation];

    // The tolerance is the larger of a fixed raw-unit allowance and the float
    // tolerance converted to raw output units, truncated to an integer.
    const double float_tolerance_in_raw_units =
        (1 << (31 - output_bits)) * kOutputFloatTolerance;
    const int tolerance = static_cast<int>(
        std::max(1.0 * kIntegerTolerance, float_tolerance_in_raw_units));

    RunUniformTest(kRunSize, input_bits, output_bits, "VariedIntegerBits",
                   tolerance, &generator_);
    RunUniformShiftUniformTest(kRunSize, input_bits, output_bits,
                               "VariedIntegerBits", tolerance, &generator_);
  }
}
295 
// Runs a large batch (100000 inputs per comparison) on one fixed
// configuration: 12 input integer bits, 5 output integer bits, and a
// tolerance of 4 raw output units.
TEST_F(LogQuantizedTest, SelectedIntegerBits) {
  constexpr int kInputBits = 12;
  constexpr int kOutputBits = 5;
  constexpr int kRunSize = 100000;
  constexpr int kIntegerTolerance = 4;
  const char* const kLabel = "SelectedIntegerBits";

  RunUniformTest(kRunSize, kInputBits, kOutputBits, kLabel, kIntegerTolerance,
                 &generator_);
  RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits, kLabel,
                             kIntegerTolerance, &generator_);
}
308 
309 }  // namespace tflite
310