• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/quantization_util.h"
16 
17 #include <limits>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/kernels/internal/common.h"
22 
23 namespace tflite {
24 namespace {
25 
26 using ::testing::ElementsAreArray;
27 using ::testing::Pair;
28 
29 template <class FloatIn, class IntOut>
RunSafeCastTests()30 void RunSafeCastTests() {
31   const IntOut imax = std::numeric_limits<IntOut>::max();
32   EXPECT_GT(imax, 0);
33   const IntOut imin = std::numeric_limits<IntOut>::min();
34   const bool s = std::numeric_limits<IntOut>::is_signed;
35   if (s) {
36     EXPECT_LT(imin, 0);
37   } else {
38     EXPECT_EQ(0, imin);
39   }
40 
41   // Some basic tests.
42   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(0.0)), 0);
43   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-0.0)), 0);
44   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(0.99)), 0);
45   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(1.0)), 1);
46   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(1.01)), 1);
47   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(1.99)), 1);
48   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(2.0)), 2);
49   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(2.01)), 2);
50   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-0.99)), 0);
51   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-1.0)), s ? -1 : 0);
52   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-1.01)), s ? -1 : 0);
53   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-1.99)), s ? -1 : 0);
54   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-2.0)), s ? -2 : 0);
55   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-2.01)), s ? -2 : 0);
56   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(117.9)), 117);
57   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(118.0)), 118);
58   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(118.1)), 118);
59   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-117.9)), s ? -117 : 0);
60   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-118.0)), s ? -118 : 0);
61   EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-118.1)), s ? -118 : 0);
62 
63   // Some edge cases.
64   EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::max()), imax);
65   EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::lowest()), imin);
66   EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::infinity()), imax);
67   EXPECT_EQ(SafeCast<IntOut>(-std::numeric_limits<FloatIn>::infinity()), imin);
68   EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::quiet_NaN()), 0);
69 
70   // Some larger numbers.
71   if (sizeof(IntOut) >= 4 && sizeof(FloatIn) > 4) {
72     EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(0x76543210)), 0x76543210);
73   }
74 
75   if (sizeof(FloatIn) > sizeof(IntOut)) {
76     // Check values near imax.
77     EXPECT_EQ(SafeCast<IntOut>(
78                   static_cast<FloatIn>(static_cast<FloatIn>(imax) + 0.1)),
79               imax);
80     EXPECT_EQ(SafeCast<IntOut>(
81                   static_cast<FloatIn>(static_cast<FloatIn>(imax) + 0.99)),
82               imax);
83     EXPECT_EQ(SafeCast<IntOut>(
84                   static_cast<FloatIn>(static_cast<FloatIn>(imax) + 1.0)),
85               imax);
86     EXPECT_EQ(SafeCast<IntOut>(
87                   static_cast<FloatIn>(static_cast<FloatIn>(imax) + 1.99)),
88               imax);
89     EXPECT_EQ(SafeCast<IntOut>(
90                   static_cast<FloatIn>(static_cast<FloatIn>(imax) + 2.0)),
91               imax);
92     EXPECT_EQ(SafeCast<IntOut>(
93                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 0.1)),
94               imax - 1);
95     EXPECT_EQ(SafeCast<IntOut>(
96                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 0.99)),
97               imax - 1);
98     EXPECT_EQ(SafeCast<IntOut>(
99                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 1.0)),
100               imax - 1);
101     EXPECT_EQ(SafeCast<IntOut>(
102                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 1.01)),
103               imax - 2);
104     EXPECT_EQ(SafeCast<IntOut>(
105                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 1.99)),
106               imax - 2);
107     EXPECT_EQ(SafeCast<IntOut>(
108                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 2.0)),
109               imax - 2);
110     EXPECT_EQ(SafeCast<IntOut>(
111                   static_cast<FloatIn>(static_cast<FloatIn>(imax) - 2.01)),
112               imax - 3);
113   }
114 
115   // Check values considerably larger in magnitude than imin and imax
116   EXPECT_EQ(
117       SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imax) * 2)),
118       imax);
119   EXPECT_EQ(
120       SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imax) * 20)),
121       imax);
122   EXPECT_EQ(
123       SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imax) * 100)),
124       imax);
125   EXPECT_EQ(
126       SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imin) * 2)),
127       imin);
128   EXPECT_EQ(
129       SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imin) * 20)),
130       imin);
131   EXPECT_EQ(
132       SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imin) * 100)),
133       imin);
134 }
135 
// Runs the SafeCast checks for every combination of {float, double} input and
// {8, 16, 32, 64}-bit signed/unsigned integer output.
TEST(QuantizationUtilTest, SafeCast) {
  // 8-bit outputs.
  RunSafeCastTests<float, int8_t>();
  RunSafeCastTests<double, int8_t>();
  RunSafeCastTests<float, uint8_t>();
  RunSafeCastTests<double, uint8_t>();

  // 16-bit outputs.
  RunSafeCastTests<float, int16_t>();
  RunSafeCastTests<double, int16_t>();
  RunSafeCastTests<float, uint16_t>();
  RunSafeCastTests<double, uint16_t>();

  // 32-bit outputs.
  RunSafeCastTests<float, int32_t>();
  RunSafeCastTests<double, int32_t>();
  RunSafeCastTests<float, uint32_t>();
  RunSafeCastTests<double, uint32_t>();

  // 64-bit outputs.
  RunSafeCastTests<float, int64_t>();
  RunSafeCastTests<double, int64_t>();
  RunSafeCastTests<float, uint64_t>();
  RunSafeCastTests<double, uint64_t>();
}
154 
155 // Example taken from http://www.tensorflow.org/performance/quantization
156 //
157 //  Quantized | Float
158 //  --------- | -----
159 //  0         | -10.0
160 //  255       | 30.0
161 //  128       | 10.0
TEST(QuantizationUtilTest,ChooseQuantizationParams)162 TEST(QuantizationUtilTest, ChooseQuantizationParams) {
163   QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 30.0);
164   EXPECT_NEAR(qp.scale, 0.156863, 1e-5);
165   EXPECT_EQ(qp.zero_point, 64);
166 }
167 
// A range starting exactly at 0.0 must place the zero point at the lowest
// quantized value (0 for uint8).
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
  const double range_min = 0.0;
  const double range_max = 30.0;
  const QuantizationParams params =
      ChooseQuantizationParams<uint8>(range_min, range_max);
  EXPECT_EQ(params.zero_point, 0);
  EXPECT_NEAR(params.scale, 0.117647, 1e-5);
}
173 
174 #ifdef GTEST_HAS_DEATH_TEST
// The range [10, 30] does not contain zero; the call is expected to terminate
// the process (any death message is accepted).
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
  const double range_min = 10.0;
  const double range_max = 30.0;
  // Assumption is that zero is within the range.
  EXPECT_DEATH(ChooseQuantizationParams<uint8>(range_min, range_max), "");
}
179 
// An empty range located at a positive value ([30, 30]) is expected to
// terminate the process (any death message is accepted).
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
  const double range_bound = 30.0;
  // Assumption is that zero is within the range.
  EXPECT_DEATH(ChooseQuantizationParams<uint8>(range_bound, range_bound), "");
}
184 #endif  // GTEST_HAS_DEATH_TEST
185 
// The degenerate range [0, 0] is accepted and is expected to produce a scale
// of (approximately) zero with a zero point of 0.
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
  const double range_bound = 0.0;
  const QuantizationParams params =
      ChooseQuantizationParams<uint8>(range_bound, range_bound);
  EXPECT_EQ(params.zero_point, 0);
  EXPECT_NEAR(params.scale, 0.0, 1e-5);
}
191 
// A range ending exactly at 0.0 must place the zero point at the highest
// quantized value (255 for uint8).
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
  const double range_min = -10.0;
  const double range_max = 0.0;
  const QuantizationParams params =
      ChooseQuantizationParams<uint8>(range_min, range_max);
  EXPECT_EQ(params.zero_point, 255);
  EXPECT_NEAR(params.scale, 0.039216, 1e-5);
}
197 
// Checks the (fraction, shift) pairs IntegerFrExp produces for zero, powers of
// two, a negative value, an arbitrary value, and the non-finite inputs NaN and
// +/-infinity.
TEST(QuantizationUtilTest, IntegerFrExp) {
  // 2^30, the fraction expected for exact powers of two.
  const int64_t kPowerOfTwoFraction = 0x40000000;
  // Shift value expected for NaN and infinite inputs.
  const int kNonFiniteShift = 0x7fffffff;

  int exponent;
  int64_t fraction = IntegerFrExp(0.0, &exponent);
  EXPECT_EQ(0, fraction);
  EXPECT_EQ(0, exponent);

  fraction = IntegerFrExp(1.0, &exponent);
  EXPECT_NEAR(kPowerOfTwoFraction, fraction, 1);
  EXPECT_EQ(1, exponent);

  fraction = IntegerFrExp(0.25, &exponent);
  EXPECT_NEAR(kPowerOfTwoFraction, fraction, 1);
  EXPECT_EQ(-1, exponent);

  fraction = IntegerFrExp(-1.0, &exponent);
  EXPECT_NEAR(-kPowerOfTwoFraction, fraction, 1);
  EXPECT_EQ(1, exponent);

  fraction = IntegerFrExp(123.45, &exponent);
  EXPECT_NEAR(2071147315, fraction, 1);
  EXPECT_EQ(7, exponent);

  // Non-finite inputs.
  fraction = IntegerFrExp(NAN, &exponent);
  EXPECT_NEAR(0, fraction, 1);
  EXPECT_EQ(kNonFiniteShift, exponent);

  fraction = IntegerFrExp(INFINITY, &exponent);
  EXPECT_NEAR(std::numeric_limits<int64_t>::max(), fraction, 1);
  EXPECT_EQ(kNonFiniteShift, exponent);

  fraction = IntegerFrExp(-INFINITY, &exponent);
  EXPECT_NEAR(std::numeric_limits<int64_t>::min(), fraction, 1);
  EXPECT_EQ(kNonFiniteShift, exponent);
}
232 
// Runs IntegerFrExp and std::frexp side by side on the same inputs and checks
// that both report the same exponent, with the integer fraction matching the
// floating-point fraction scaled by 2^31.
TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
  int shift;
  // The fraction is held in an int64_t, as in the other IntegerFrExp tests in
  // this file, so the returned value is never implicitly narrowed.
  int64_t result = IntegerFrExp(0.0, &shift);
  EXPECT_EQ(result, 0);
  EXPECT_EQ(shift, 0);

  int double_shift;
  double double_result = std::frexp(0.0, &double_shift);
  EXPECT_EQ(double_result, 0);
  EXPECT_EQ(double_shift, 0);

  result = IntegerFrExp(1.0, &shift);
  EXPECT_NEAR(result, 0x40000000, 1);
  EXPECT_EQ(shift, 1);
  double_result = std::frexp(1.0, &double_shift);
  EXPECT_NEAR(double_result, 0.5, 1e-5);
  EXPECT_EQ(double_shift, 1);

  result = IntegerFrExp(0.25, &shift);
  EXPECT_NEAR(result, 0x40000000, 1);
  EXPECT_EQ(shift, -1);
  double_result = std::frexp(0.25, &double_shift);
  EXPECT_NEAR(double_result, 0.5, 1e-5);
  EXPECT_EQ(double_shift, -1);

  result = IntegerFrExp(-1.0, &shift);
  EXPECT_NEAR(result, -(1 << 30), 1);
  EXPECT_EQ(shift, 1);
  double_result = std::frexp(-1.0, &double_shift);
  EXPECT_NEAR(double_result, -0.5, 1e-5);
  EXPECT_EQ(double_shift, 1);

  result = IntegerFrExp(123.45, &shift);
  // 0.964453 is only a 6-digit approximation of the fraction, hence the loose
  // tolerance on the scaled integer value.
  EXPECT_NEAR(result, (0.964453 * (1LL << 31)), 1000);
  EXPECT_EQ(shift, 7);
  double_result = std::frexp(123.45, &double_shift);
  EXPECT_NEAR(double_result, 0.964453, 1e-5);
  EXPECT_EQ(double_shift, 7);
}
272 
// Checks DoubleFromFractionAndShift on hand-written (fraction, shift) pairs and
// on pairs produced by IntegerFrExp, i.e. that decomposing a double and
// recomposing it yields (approximately) the starting value.
TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
  // Hand-written inputs.
  EXPECT_EQ(0, DoubleFromFractionAndShift(0, 0));
  EXPECT_NEAR(1.0, DoubleFromFractionAndShift(0x40000000, 1), 1e-5);
  EXPECT_NEAR(2.0, DoubleFromFractionAndShift(0x40000000, 2), 1e-5);

  // Splits `value` with IntegerFrExp and reassembles it.
  const auto round_trip = [](double value) -> double {
    int shift;
    const int64_t fraction = IntegerFrExp(value, &shift);
    return DoubleFromFractionAndShift(fraction, shift);
  };

  EXPECT_NEAR(3.0, round_trip(3.0), 1e-5);
  EXPECT_NEAR(123.45, round_trip(123.45), 1e-5);
  EXPECT_NEAR(-23.232323, round_trip(-23.232323), 1e-5);

  // Non-finite values must survive the round trip as non-finite values.
  EXPECT_TRUE(std::isnan(round_trip(NAN)));
  EXPECT_FALSE(std::isfinite(round_trip(INFINITY)));
}
304 
// Checks IntegerDoubleMultiply against known products, covering sign
// combinations, fractional operands, a large product, and NaN operands.
TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
  const double kTolerance = 1e-5;

  // Positive operands.
  EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), kTolerance);
  EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), kTolerance);
  EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), kTolerance);
  EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), kTolerance);

  // Fractional operands.
  EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), kTolerance);
  EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), kTolerance);

  // Sign handling.
  EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), kTolerance);
  EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), kTolerance);
  EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), kTolerance);

  // Larger magnitude.
  EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), kTolerance);

  // NaN in either position yields NaN.
  EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
  EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
}
319 
// Checks the three-way result of IntegerDoubleCompare: -1 when the first
// argument is smaller, 0 when equal, 1 when larger. Comparisons involving NaN
// are expected to return 1.
TEST(QuantizationUtilTest, IntegerDoubleCompare) {
  const int kLess = -1;
  const int kEqual = 0;
  const int kGreater = 1;

  EXPECT_EQ(kLess, IntegerDoubleCompare(0.0, 1.0));
  EXPECT_EQ(kGreater, IntegerDoubleCompare(1.0, 0.0));
  EXPECT_EQ(kEqual, IntegerDoubleCompare(1.0, 1.0));
  EXPECT_EQ(kEqual, IntegerDoubleCompare(0.0, 0.0));
  EXPECT_EQ(kLess, IntegerDoubleCompare(-10.0, 10.0));
  EXPECT_EQ(kGreater, IntegerDoubleCompare(123.45, 10.0));

  // NaN on either side.
  EXPECT_EQ(kGreater, IntegerDoubleCompare(NAN, INFINITY));
  EXPECT_EQ(kGreater, IntegerDoubleCompare(INFINITY, NAN));
}
330 
331 #ifdef GTEST_HAS_DEATH_TEST
// A range whose minimum (10) exceeds its maximum (-30) is expected to
// terminate the process (any death message is accepted).
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
  const double range_min = 10.0;
  const double range_max = -30.0;
  EXPECT_DEATH(ChooseQuantizationParams<uint8>(range_min, range_max), "");
}
335 
// Checks QuantizeMultiplierSmallerThanOneExp: the (multiplier, shift) pairs it
// produces for values inside (0, 1), and that values outside that interval (or
// too close to 1.0) terminate the process.
TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) {
  // Returns the (quantized multiplier, shift) pair computed for `value`.
  const auto quantized = [](double value) {
    int32_t multiplier;
    int shift;
    QuantizeMultiplierSmallerThanOneExp(value, &multiplier, &shift);
    return std::pair<int32_t, int>(multiplier, shift);
  };

  // Non-positive inputs are rejected.
  EXPECT_DEATH(quantized(-0.1), "");
  EXPECT_DEATH(quantized(0.0), "");

  EXPECT_THAT(quantized(0.25), Pair(1073741824, -1));

  // Around 0.5 we can see the change in exponent and how we try hard to
  // avoid hitting max int32.
  EXPECT_THAT(quantized(0.50 - 5e-9), Pair(2147483627, -1));
  EXPECT_THAT(quantized(0.50 - 1e-10), Pair(1073741824, 0));
  EXPECT_THAT(quantized(0.50), Pair(1073741824, 0));

  EXPECT_THAT(quantized(0.75), Pair(1610612736, 0));
  EXPECT_THAT(quantized(1 - 1e-9), Pair(2147483646, 0));

  // If we get close enough to 1.0 it crashes and dies in one of two ways:
  // Either the shift becomes negative or we trigger the 'less-than-one' CHECK.
  EXPECT_DEATH(quantized(1 - 1e-15), "");
  EXPECT_DEATH(quantized(1 - 1e-17), "");
  EXPECT_DEATH(quantized(1.0), "");
}
363 
// Checks QuantizeMultiplierGreaterThanOne: the (multiplier, shift) pairs it
// produces for values in (1, 2], and that a value indistinguishable from 1.0
// terminates the process.
TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) {
  // Returns the (quantized multiplier, shift) pair computed for `value`.
  const auto quantized = [](double value) {
    int32_t multiplier;
    int shift;
    QuantizeMultiplierGreaterThanOne(value, &multiplier, &shift);
    return std::pair<int32_t, int>(multiplier, shift);
  };

  // If we are close enough to 1.0 it crashes.
  EXPECT_DEATH(quantized(1 + 1e-16), "");

  // Values strictly between 1 and 2 share a shift of 1.
  EXPECT_THAT(quantized(1 + 1e-11), Pair(1073741824, 1));
  EXPECT_THAT(quantized(1.25), Pair(1342177280, 1));
  EXPECT_THAT(quantized(1.50), Pair(1610612736, 1));
  EXPECT_THAT(quantized(1.75), Pair(1879048192, 1));

  // Around the powers of two we see the change in exponent. Also,
  // we try hard to avoid hitting max int32.
  EXPECT_THAT(quantized(2 - 1e-9), Pair(2147483647, 1));
  EXPECT_THAT(quantized(2 - 1e-11), Pair(1073741824, 2));
  EXPECT_THAT(quantized(2), Pair(1073741824, 2));
}
386 
387 #ifndef __APPLE__  // Some Apple toolchains don't support std::ldexp
// Checks how QuantizeMultiplier behaves for very small multipliers: 2^-31 and
// 2^-32 still quantize to a non-zero multiplier, while anything smaller is
// expected to collapse to (0, 0).
TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
  // Returns the (quantized multiplier, shift) pair computed for `value`.
  const auto quantized = [](double value) {
    int32_t multiplier;
    int shift;
    QuantizeMultiplier(value, &multiplier, &shift);
    return std::pair<int32_t, int>(multiplier, shift);
  };

  // Still representable.
  EXPECT_THAT(quantized(std::ldexp(1.0f, -31)), Pair(1073741824, -30));
  EXPECT_THAT(quantized(std::ldexp(1.0f, -32)), Pair(1073741824, -31));

  // Below the smallest representable multiplier.
  EXPECT_THAT(quantized(std::ldexp(0.99f, -32)), Pair(0, 0));
  EXPECT_THAT(quantized(std::ldexp(1.0f, -33)), Pair(0, 0));
}
401 #endif
402 
// Checks the (multiplier, shift) pairs GetInvSqrtQuantizedMultiplierExp
// produces for a range of inputs from 0 up to the int32 maximum, always
// passing 1 as the second argument.
TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
  // Returns the (output, output_shift) pair computed for `value`.
  const auto inverse_sqrt = [](std::int32_t value) {
    int32_t multiplier;
    int shift;
    GetInvSqrtQuantizedMultiplierExp(value, 1, &multiplier, &shift);
    return std::pair<int32_t, int>(multiplier, shift);
  };

  const auto kMaxInt32 = std::numeric_limits<std::int32_t>::max();

  // Inputs 0 and 1 both saturate to the int32 maximum.
  EXPECT_THAT(inverse_sqrt(0), Pair(kMaxInt32, 0));
  EXPECT_THAT(inverse_sqrt(1), Pair(kMaxInt32, 0));

  // Small inputs: shift stays at 0.
  EXPECT_THAT(inverse_sqrt(2), Pair(1518498372, 0));
  EXPECT_THAT(inverse_sqrt(3), Pair(1239850284, 0));
  EXPECT_THAT(inverse_sqrt(4), Pair(1073741828, 0));
  EXPECT_THAT(inverse_sqrt(100), Pair(214748363, 0));

  // Larger inputs: a non-zero shift is reported.
  EXPECT_THAT(inverse_sqrt(10000), Pair(343597361, 4));
  EXPECT_THAT(inverse_sqrt(1000000), Pair(274877901, 7));
  EXPECT_THAT(inverse_sqrt(100000000), Pair(219902323, 10));
  EXPECT_THAT(inverse_sqrt((1 << 30)), Pair(268435457, 12));
  EXPECT_THAT(inverse_sqrt(kMaxInt32), Pair(189812531, 12));
}
424 
// Checks the int32 MultiplyByQuantizedMultiplier: first through multipliers
// quantized by QuantizeMultiplier, then directly with the largest possible
// input and quantized multiplier.
TEST(QuantizationUtilTest, MultiplyByQuantizedMultiplierInt32) {
  // Quantizes `real_multiplier` and applies it to `value`.
  const auto scaled = [](int32_t value, double real_multiplier) {
    int32_t multiplier;
    int exponent;
    QuantizeMultiplier(real_multiplier, &multiplier, &exponent);
    return MultiplyByQuantizedMultiplier(value, multiplier, exponent);
  };

  const int32_t kMin = std::numeric_limits<int32_t>::min();
  const int32_t kMax = std::numeric_limits<int32_t>::max();

  EXPECT_EQ(scaled(0, 0.1), 0);
  EXPECT_EQ(scaled(1, 0), 0);
  EXPECT_EQ(scaled(10000, 0.00097656), 10);
  EXPECT_EQ(scaled(-10000, 0.00097656), -10);
  EXPECT_EQ(scaled(kMin, 0.00001), -21475);
  EXPECT_EQ(scaled(kMax, 0.00001), 21475);
#if !TFLITE_SINGLE_ROUNDING
  // Single-rounding doesn't support negative multipliers, only test negative
  // multipliers in double-rounding mode.
  EXPECT_EQ(scaled(10000, -0.00097656), -10);
  EXPECT_EQ(scaled(-10000, -0.00097656), 10);
  EXPECT_EQ(scaled(kMin, -0.00001), 21475);
  EXPECT_EQ(scaled(kMax, -0.00001), -21475);
#endif

  // Test with maximum possible x and quantized_multiplier
  const int32_t input = kMax;
  const int32_t max_multiplier = kMax;
  const int exponent = -3;
  // Reference value: the exact 64-bit product divided by 2^(31 - exponent),
  // rounded to the nearest integer.
  const int64_t product = static_cast<int64_t>(input) * max_multiplier;
  const double divisor = static_cast<double>(1LL << (31 - exponent));
  const int32_t expected = static_cast<int32_t>(TfLiteRound(product / divisor));
  EXPECT_EQ(MultiplyByQuantizedMultiplier(input, max_multiplier, exponent),
            expected);
  EXPECT_EQ(MultiplyByQuantizedMultiplier(-input, max_multiplier, exponent),
            -expected);
}
464 
// Checks the int64 MultiplyByQuantizedMultiplier: first through multipliers
// quantized by QuantizeMultiplier, then directly with the largest input used
// by this test (2^47 - 1) and the largest quantized multiplier.
TEST(QuantizationUtilTest, MultiplyByQuantizedMultiplierInt64) {
  // Quantizes `real_multiplier` and applies it to `value`.
  const auto scaled = [](int64_t value, double real_multiplier) {
    int32_t multiplier;
    int exponent;
    QuantizeMultiplier(real_multiplier, &multiplier, &exponent);
    return MultiplyByQuantizedMultiplier(value, multiplier, exponent);
  };

  const int64_t kMaxInput = (1LL << 47) - 1;

  // Negative multipliers are not supported by the 64-bit
  // MultiplyByQuantizedMultiplier, only use >= 0 multipliers.
  EXPECT_EQ(scaled(0, 0.1), 0);
  EXPECT_EQ(scaled(1, 0), 0);
  EXPECT_EQ(scaled(10000, 0.00097656), 10);
  EXPECT_EQ(scaled(-10000, 0.00097656), -10);
  EXPECT_EQ(scaled(-(1LL << 47), 0.00001), -1407385600);
  EXPECT_EQ(scaled(kMaxInput, 0.00001), 1407385600);

  // Test with maximum possible x and quantized_multiplier
  const int64_t input = kMaxInput;
  const int32_t max_multiplier = std::numeric_limits<int32_t>::max();
  const int exponent = -31;
  // Expected is around 'x * quantized_multiplier / 2**(31 - shift)' ~= 65536
  // As there is some rounding error, expected is a bit smaller.
  const int32_t expected = 65534;
  EXPECT_EQ(MultiplyByQuantizedMultiplier(input, max_multiplier, exponent),
            expected);
  EXPECT_EQ(MultiplyByQuantizedMultiplier(-input, max_multiplier, exponent),
            -expected);
}
494 
// Checks the (multiplier, shift) pairs PreprocessSoftmaxScaling produces for
// a given beta, scale and number of integer bits. The expected values differ
// between the single-rounding and double-rounding build configurations.
TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
  // Returns the (quantized multiplier, shift) pair for the given arguments.
  const auto scaling = [](double beta, double scale, int integer_bits) {
    int32_t multiplier;
    int shift;
    PreprocessSoftmaxScaling(beta, scale, integer_bits, &multiplier, &shift);
    return std::pair<int32_t, int>(multiplier, shift);
  };

#if TFLITE_SINGLE_ROUNDING
  // If beta * scale is greater than fits in the number of integer bits, the
  // result is moved near the maximum. Otherwise they quantize as expected.
  // With 4 integer bits we can represent up to 8.0.
  EXPECT_THAT(scaling(1.0, 8.0, 4), Pair(2147483646, 30));
  EXPECT_THAT(scaling(1.0, 4.0, 4), Pair(1073741824, 30));
  // But with 5 bits we can go further.
  EXPECT_THAT(scaling(2.0, 8.0, 5), Pair(2147483646, 30));
  EXPECT_THAT(scaling(2.0, 4.0, 5), Pair(1073741824, 30));
#else
  // If beta * scale is greater than fits in the number of integer bits, the
  // result is moved near the maximum. Otherwise they quantize as expected.
  // With 4 integer bits we can represent up to 16.0.
  EXPECT_THAT(scaling(1.0, 16.0, 4), Pair(2147483647, 31));
  EXPECT_THAT(scaling(1.0, 8.0, 4), Pair(1073741824, 31));
  // But with 5 bits we can go further.
  EXPECT_THAT(scaling(2.0, 16.0, 5), Pair(2147483647, 31));
  EXPECT_THAT(scaling(2.0, 8.0, 5), Pair(1073741824, 31));
#endif
}
523 #endif  // GTEST_HAS_DEATH_TEST
524 
// Pins the value CalculateInputRadius returns for a handful of
// (first argument, second argument) combinations.
TEST(QuantizationUtilTest, CalculateInputRadius) {
  const auto radius_4_27 = CalculateInputRadius(4, 27);
  const auto radius_3_27 = CalculateInputRadius(3, 27);
  const auto radius_3_28 = CalculateInputRadius(3, 28);
  const auto radius_4_2 = CalculateInputRadius(4, 2);

  EXPECT_EQ(radius_4_27, 15);
  EXPECT_EQ(radius_3_27, 14);
  EXPECT_EQ(radius_3_28, 7);
  EXPECT_EQ(radius_4_2, 503316480);
}
531 
// Quantizes an array of power-of-two scales (negative, zero and positive) in
// one QuantizeMultiplierArray call and checks every resulting significand and
// shift. All non-zero scales are expected to share the magnitude 2^30 and
// differ only in sign and shift.
TEST(QuantizationUtilTest, QuantizeMultiplierArray) {
  const std::vector<double> scales = {-4,    -2,   -1,  -0.5, -0.25, -0.125, 0,
                                      0.125, 0.25, 0.5, 1,    2,     4};
  const int count = static_cast<int>(scales.size());

  // Expected outputs, listed in the same order as `scales`.
  const std::vector<int32> want_significands = {
      -1073741824,  // float scale = -4
      -1073741824,  // float scale = -2
      -1073741824,  // float scale = -1
      -1073741824,  // float scale = -0.5
      -1073741824,  // float scale = -0.25
      -1073741824,  // float scale = -0.125
      0,            // float scale = 0
      1073741824,   // float scale = 0.125
      1073741824,   // float scale = 0.25
      1073741824,   // float scale = 0.5
      1073741824,   // float scale = 1
      1073741824,   // float scale = 2
      1073741824,   // float scale = 4
  };
  const std::vector<int> want_shifts = {
      3,   // float scale = -4
      2,   // float scale = -2
      1,   // float scale = -1
      0,   // float scale = -0.5
      -1,  // float scale = -0.25
      -2,  // float scale = -0.125
      0,   // float scale = 0
      -2,  // float scale = 0.125
      -1,  // float scale = 0.25
      0,   // float scale = 0.5
      1,   // float scale = 1
      2,   // float scale = 2
      3,   // float scale = 4
  };

  std::vector<int32> got_significands(count);
  std::vector<int> got_shifts(count);
  QuantizeMultiplierArray(scales.data(), count, got_significands.data(),
                          got_shifts.data());

  EXPECT_THAT(got_significands, ElementsAreArray(want_significands));
  EXPECT_THAT(got_shifts, ElementsAreArray(want_shifts));
}
577 
578 }  // namespace
579 }  // namespace tflite
580