1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/quantization_util.h"
16
17 #include <limits>
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/kernels/internal/common.h"
22
23 namespace tflite {
24 namespace {
25
26 using ::testing::ElementsAreArray;
27 using ::testing::Pair;
28
29 template <class FloatIn, class IntOut>
RunSafeCastTests()30 void RunSafeCastTests() {
31 const IntOut imax = std::numeric_limits<IntOut>::max();
32 EXPECT_GT(imax, 0);
33 const IntOut imin = std::numeric_limits<IntOut>::min();
34 const bool s = std::numeric_limits<IntOut>::is_signed;
35 if (s) {
36 EXPECT_LT(imin, 0);
37 } else {
38 EXPECT_EQ(0, imin);
39 }
40
41 // Some basic tests.
42 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(0.0)), 0);
43 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-0.0)), 0);
44 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(0.99)), 0);
45 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(1.0)), 1);
46 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(1.01)), 1);
47 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(1.99)), 1);
48 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(2.0)), 2);
49 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(2.01)), 2);
50 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-0.99)), 0);
51 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-1.0)), s ? -1 : 0);
52 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-1.01)), s ? -1 : 0);
53 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-1.99)), s ? -1 : 0);
54 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-2.0)), s ? -2 : 0);
55 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-2.01)), s ? -2 : 0);
56 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(117.9)), 117);
57 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(118.0)), 118);
58 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(118.1)), 118);
59 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-117.9)), s ? -117 : 0);
60 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-118.0)), s ? -118 : 0);
61 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(-118.1)), s ? -118 : 0);
62
63 // Some edge cases.
64 EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::max()), imax);
65 EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::lowest()), imin);
66 EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::infinity()), imax);
67 EXPECT_EQ(SafeCast<IntOut>(-std::numeric_limits<FloatIn>::infinity()), imin);
68 EXPECT_EQ(SafeCast<IntOut>(std::numeric_limits<FloatIn>::quiet_NaN()), 0);
69
70 // Some larger numbers.
71 if (sizeof(IntOut) >= 4 && sizeof(FloatIn) > 4) {
72 EXPECT_EQ(SafeCast<IntOut>(static_cast<FloatIn>(0x76543210)), 0x76543210);
73 }
74
75 if (sizeof(FloatIn) > sizeof(IntOut)) {
76 // Check values near imax.
77 EXPECT_EQ(SafeCast<IntOut>(
78 static_cast<FloatIn>(static_cast<FloatIn>(imax) + 0.1)),
79 imax);
80 EXPECT_EQ(SafeCast<IntOut>(
81 static_cast<FloatIn>(static_cast<FloatIn>(imax) + 0.99)),
82 imax);
83 EXPECT_EQ(SafeCast<IntOut>(
84 static_cast<FloatIn>(static_cast<FloatIn>(imax) + 1.0)),
85 imax);
86 EXPECT_EQ(SafeCast<IntOut>(
87 static_cast<FloatIn>(static_cast<FloatIn>(imax) + 1.99)),
88 imax);
89 EXPECT_EQ(SafeCast<IntOut>(
90 static_cast<FloatIn>(static_cast<FloatIn>(imax) + 2.0)),
91 imax);
92 EXPECT_EQ(SafeCast<IntOut>(
93 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 0.1)),
94 imax - 1);
95 EXPECT_EQ(SafeCast<IntOut>(
96 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 0.99)),
97 imax - 1);
98 EXPECT_EQ(SafeCast<IntOut>(
99 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 1.0)),
100 imax - 1);
101 EXPECT_EQ(SafeCast<IntOut>(
102 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 1.01)),
103 imax - 2);
104 EXPECT_EQ(SafeCast<IntOut>(
105 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 1.99)),
106 imax - 2);
107 EXPECT_EQ(SafeCast<IntOut>(
108 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 2.0)),
109 imax - 2);
110 EXPECT_EQ(SafeCast<IntOut>(
111 static_cast<FloatIn>(static_cast<FloatIn>(imax) - 2.01)),
112 imax - 3);
113 }
114
115 // Check values considerably larger in magnitude than imin and imax
116 EXPECT_EQ(
117 SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imax) * 2)),
118 imax);
119 EXPECT_EQ(
120 SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imax) * 20)),
121 imax);
122 EXPECT_EQ(
123 SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imax) * 100)),
124 imax);
125 EXPECT_EQ(
126 SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imin) * 2)),
127 imin);
128 EXPECT_EQ(
129 SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imin) * 20)),
130 imin);
131 EXPECT_EQ(
132 SafeCast<IntOut>(static_cast<FloatIn>(static_cast<FloatIn>(imin) * 100)),
133 imin);
134 }
135
TEST(QuantizationUtilTest,SafeCast)136 TEST(QuantizationUtilTest, SafeCast) {
137 RunSafeCastTests<float, int8_t>();
138 RunSafeCastTests<double, int8_t>();
139 RunSafeCastTests<float, int16_t>();
140 RunSafeCastTests<double, int16_t>();
141 RunSafeCastTests<float, int32_t>();
142 RunSafeCastTests<double, int32_t>();
143 RunSafeCastTests<float, int64_t>();
144 RunSafeCastTests<double, int64_t>();
145 RunSafeCastTests<float, uint8_t>();
146 RunSafeCastTests<double, uint8_t>();
147 RunSafeCastTests<float, uint16_t>();
148 RunSafeCastTests<double, uint16_t>();
149 RunSafeCastTests<float, uint32_t>();
150 RunSafeCastTests<double, uint32_t>();
151 RunSafeCastTests<float, uint64_t>();
152 RunSafeCastTests<double, uint64_t>();
153 }
154
155 // Example taken from http://www.tensorflow.org/performance/quantization
156 //
157 // Quantized | Float
158 // --------- | -----
159 // 0 | -10.0
160 // 255 | 30.0
161 // 128 | 10.0
TEST(QuantizationUtilTest,ChooseQuantizationParams)162 TEST(QuantizationUtilTest, ChooseQuantizationParams) {
163 QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 30.0);
164 EXPECT_NEAR(qp.scale, 0.156863, 1e-5);
165 EXPECT_EQ(qp.zero_point, 64);
166 }
167
// A range that starts exactly at 0.0 puts the zero point on the lowest
// quantized value.
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
  const QuantizationParams params = ChooseQuantizationParams<uint8>(0.0, 30.0);
  EXPECT_EQ(params.zero_point, 0);
  EXPECT_NEAR(params.scale, 0.117647, 1e-5);
}
173
#ifdef GTEST_HAS_DEATH_TEST
// ChooseQuantizationParams assumes 0.0 lies inside [rmin, rmax]. Ranges that
// exclude zero are expected to abort.
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
  EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), "");
}

// A degenerate range away from zero also excludes 0.0, so it aborts as well.
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
  EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), "");
}
#endif  // GTEST_HAS_DEATH_TEST
185
// The degenerate range [0.0, 0.0] is accepted. It yields a zero scale and a
// zero point of 0.
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
  const QuantizationParams params = ChooseQuantizationParams<uint8>(0.0, 0.0);
  EXPECT_EQ(params.zero_point, 0);
  EXPECT_NEAR(params.scale, 0.0, 1e-5);
}
191
// A range that ends exactly at 0.0 puts the zero point on the highest uint8
// value.
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
  const QuantizationParams params = ChooseQuantizationParams<uint8>(-10.0, 0.0);
  EXPECT_EQ(params.zero_point, 255);
  EXPECT_NEAR(params.scale, 0.039216, 1e-5);
}
197
// IntegerFrExp(x, &shift) returns the fraction of x as a fixed-point integer,
// where a fraction of 0.5 is 0x40000000, and writes the power-of-two exponent
// to `shift`. Non-finite inputs report a shift of 0x7fffffff.
TEST(QuantizationUtilTest, IntegerFrExp) {
  int shift;

  // Zero: both the fraction and the shift are zero.
  EXPECT_EQ(0, IntegerFrExp(0.0, &shift));
  EXPECT_EQ(0, shift);

  // Powers of two share the same fraction and differ only in the shift.
  EXPECT_NEAR(0x40000000, IntegerFrExp(1.0, &shift), 1);
  EXPECT_EQ(1, shift);
  EXPECT_NEAR(0x40000000, IntegerFrExp(0.25, &shift), 1);
  EXPECT_EQ(-1, shift);

  // The sign is carried by the fraction, not the shift.
  EXPECT_NEAR(-(1 << 30), IntegerFrExp(-1.0, &shift), 1);
  EXPECT_EQ(1, shift);

  // A value that is not a power of two.
  EXPECT_NEAR(2071147315, IntegerFrExp(123.45, &shift), 1);
  EXPECT_EQ(7, shift);

  // Non-finite inputs: NaN gives a zero fraction, infinities saturate it.
  EXPECT_NEAR(0, IntegerFrExp(NAN, &shift), 1);
  EXPECT_EQ(0x7fffffff, shift);
  EXPECT_NEAR(std::numeric_limits<int64_t>::max(),
              IntegerFrExp(INFINITY, &shift), 1);
  EXPECT_EQ(0x7fffffff, shift);
  EXPECT_NEAR(std::numeric_limits<int64_t>::min(),
              IntegerFrExp(-INFINITY, &shift), 1);
  EXPECT_EQ(0x7fffffff, shift);
}
232
// Cross-checks IntegerFrExp against std::frexp. Both split a value into a
// mantissa and a power-of-two exponent. IntegerFrExp expresses the mantissa
// as a fixed-point integer scaled by 2^31 (0.5 -> 0x40000000), so for every
// input the two shifts must agree.
TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
  int shift;
  // IntegerFrExp returns int64_t. Keep the full width instead of narrowing
  // to int32_t, so a wrong, out-of-range fraction cannot be silently
  // truncated into something that happens to look plausible.
  int64_t result = IntegerFrExp(0.0, &shift);
  EXPECT_EQ(result, 0);
  EXPECT_EQ(shift, 0);

  int double_shift;
  double double_result = std::frexp(0.0, &double_shift);
  EXPECT_EQ(double_result, 0);
  EXPECT_EQ(double_shift, 0);

  result = IntegerFrExp(1.0, &shift);
  EXPECT_NEAR(result, 0x40000000, 1);
  EXPECT_EQ(shift, 1);
  double_result = std::frexp(1.0, &double_shift);
  EXPECT_NEAR(double_result, 0.5, 1e-5);
  EXPECT_EQ(double_shift, 1);

  result = IntegerFrExp(0.25, &shift);
  EXPECT_NEAR(result, 0x40000000, 1);
  EXPECT_EQ(shift, -1);
  double_result = std::frexp(0.25, &double_shift);
  EXPECT_NEAR(double_result, 0.5, 1e-5);
  EXPECT_EQ(double_shift, -1);

  result = IntegerFrExp(-1.0, &shift);
  EXPECT_NEAR(result, -(1 << 30), 1);
  EXPECT_EQ(shift, 1);
  double_result = std::frexp(-1.0, &double_shift);
  EXPECT_NEAR(double_result, -0.5, 1e-5);
  EXPECT_EQ(double_shift, 1);

  result = IntegerFrExp(123.45, &shift);
  // Same mantissa as std::frexp below, rescaled by 2^31.
  EXPECT_NEAR(result, (0.964453 * (1LL << 31)), 1000);
  EXPECT_EQ(shift, 7);
  double_result = std::frexp(123.45, &double_shift);
  EXPECT_NEAR(double_result, 0.964453, 1e-5);
  EXPECT_EQ(double_shift, 7);
}
272
// DoubleFromFractionAndShift rebuilds a double from a fixed-point fraction
// and a power-of-two shift. Feeding it IntegerFrExp's output should recover
// the original value, including non-finite ones.
TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
  // Decomposes `value` with IntegerFrExp, then recomposes it.
  auto round_trip = [](double value) {
    int shift;
    const int64_t fraction = IntegerFrExp(value, &shift);
    return DoubleFromFractionAndShift(fraction, shift);
  };

  // Hand-built (fraction, shift) pairs.
  EXPECT_EQ(0, DoubleFromFractionAndShift(0, 0));
  EXPECT_NEAR(1.0, DoubleFromFractionAndShift(0x40000000, 1), 1e-5);
  EXPECT_NEAR(2.0, DoubleFromFractionAndShift(0x40000000, 2), 1e-5);

  // Round trips through IntegerFrExp.
  EXPECT_NEAR(3.0, round_trip(3.0), 1e-5);
  EXPECT_NEAR(123.45, round_trip(123.45), 1e-5);
  EXPECT_NEAR(-23.232323, round_trip(-23.232323), 1e-5);
  EXPECT_TRUE(std::isnan(round_trip(NAN)));
  EXPECT_FALSE(std::isfinite(round_trip(INFINITY)));
}
304
// IntegerDoubleMultiply should agree with ordinary double multiplication on
// finite inputs and propagate NaN from either operand.
TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
  struct Product {
    double lhs;
    double rhs;
    double expected;
  };
  const Product products[] = {
      {1.0, 1.0, 1.0},    {1.0, 2.0, 2.0},   {2.0, 1.0, 2.0},
      {2.0, 2.0, 4.0},    {1.0, 0.5, 0.5},   {0.5, 0.5, 0.25},
      {1.0, -1.0, -1.0},  {-1.0, 1.0, -1.0}, {-1.0, -1.0, 1.0},
      {3000.0, 5000.0, 15000000.0},
  };
  for (const Product& p : products) {
    EXPECT_NEAR(p.expected, IntegerDoubleMultiply(p.lhs, p.rhs), 1e-5)
        << p.lhs << " * " << p.rhs;
  }

  EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
  EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
}
319
// IntegerDoubleCompare acts as a three-way comparison returning -1, 0 or 1.
TEST(QuantizationUtilTest, IntegerDoubleCompare) {
  // Equal operands.
  EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0));
  EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0));

  // Left operand smaller.
  EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0));
  EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0));

  // Left operand larger.
  EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0));
  EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0));

  // NaN paired with infinity reports 1 in both argument orders.
  EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY));
  EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN));
}
330
331 #ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest,ChooseQuantizationParamsInvalidRange)332 TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
333 EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
334 }
335
TEST(QuantizationUtilTest,QuantizeMultiplierSmallerThanOneExp)336 TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) {
337 auto quantize = [](double d) {
338 int32_t q;
339 int s;
340 QuantizeMultiplierSmallerThanOneExp(d, &q, &s);
341 return std::pair<int32_t, int>{q, s};
342 };
343
344 EXPECT_DEATH(quantize(-0.1), "");
345 EXPECT_DEATH(quantize(0.0), "");
346 EXPECT_THAT(quantize(0.25), Pair(1073741824, -1));
347
348 // Around 0.5 we can see the change in exponent and how we try hard to
349 // void hitting max int32.
350 EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1));
351 EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0));
352 EXPECT_THAT(quantize(0.50), Pair(1073741824, 0));
353
354 EXPECT_THAT(quantize(0.75), Pair(1610612736, 0));
355 EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0));
356
357 // If we get close enough to 1.0 it crashes and dies in one of two ways:
358 // Either the shift becomes negative or we trigger the 'less-than-one' CHECK.
359 EXPECT_DEATH(quantize(1 - 1e-15), "");
360 EXPECT_DEATH(quantize(1 - 1e-17), "");
361 EXPECT_DEATH(quantize(1.0), "");
362 }
363
TEST(QuantizationUtilTest,QuantizeMultiplierGreaterThanOne)364 TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) {
365 auto quantize = [](double d) {
366 int32_t q;
367 int s;
368 QuantizeMultiplierGreaterThanOne(d, &q, &s);
369 return std::pair<int32_t, int>{q, s};
370 };
371
372 // If we are close enough to 1.0 it crashes.
373 EXPECT_DEATH(quantize(1 + 1e-16), "");
374
375 EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1));
376 EXPECT_THAT(quantize(1.25), Pair(1342177280, 1));
377 EXPECT_THAT(quantize(1.50), Pair(1610612736, 1));
378 EXPECT_THAT(quantize(1.75), Pair(1879048192, 1));
379
380 // Around the powers of two we see the change in exponent. Also,
381 // we try hard to avoid hitting max int32.
382 EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1));
383 EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2));
384 EXPECT_THAT(quantize(2), Pair(1073741824, 2));
385 }
386
387 #ifndef __APPLE__ // Some Apple toolchains don't support std::ldexp
TEST(QuantizationUtilTest,QuantizeMultiplierUnderflow)388 TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
389 auto quantize = [](double d) {
390 int32_t q;
391 int s;
392 QuantizeMultiplier(d, &q, &s);
393 return std::pair<int32_t, int>{q, s};
394 };
395
396 EXPECT_THAT(quantize(std::ldexp(1.0f, -31)), Pair(1073741824, -30));
397 EXPECT_THAT(quantize(std::ldexp(1.0f, -32)), Pair(1073741824, -31));
398 EXPECT_THAT(quantize(std::ldexp(0.99f, -32)), Pair(0, 0));
399 EXPECT_THAT(quantize(std::ldexp(1.0f, -33)), Pair(0, 0));
400 }
401 #endif
402
TEST(QuantizationUtilTest,GetInvSqrtQuantizedMultiplierExp)403 TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
404 auto inv_sqrt = [](std::int32_t input) {
405 int32_t output;
406 int output_shift;
407 GetInvSqrtQuantizedMultiplierExp(input, 1, &output, &output_shift);
408 return std::pair<int32_t, int>{output, output_shift};
409 };
410
411 const auto kInt32Max = std::numeric_limits<std::int32_t>::max();
412 EXPECT_THAT(inv_sqrt(0), Pair(kInt32Max, 0));
413 EXPECT_THAT(inv_sqrt(1), Pair(kInt32Max, 0));
414 EXPECT_THAT(inv_sqrt(2), Pair(1518498372, 0));
415 EXPECT_THAT(inv_sqrt(3), Pair(1239850284, 0));
416 EXPECT_THAT(inv_sqrt(4), Pair(1073741828, 0));
417 EXPECT_THAT(inv_sqrt(100), Pair(214748363, 0));
418 EXPECT_THAT(inv_sqrt(10000), Pair(343597361, 4));
419 EXPECT_THAT(inv_sqrt(1000000), Pair(274877901, 7));
420 EXPECT_THAT(inv_sqrt(100000000), Pair(219902323, 10));
421 EXPECT_THAT(inv_sqrt((1 << 30)), Pair(268435457, 12));
422 EXPECT_THAT(inv_sqrt(kInt32Max), Pair(189812531, 12));
423 }
424
TEST(QuantizationUtilTest,MultiplyByQuantizedMultiplierInt32)425 TEST(QuantizationUtilTest, MultiplyByQuantizedMultiplierInt32) {
426 auto quant_and_multiply = [](int32_t x, double multiplier) {
427 int32_t quantized_multiplier;
428 int shift;
429 QuantizeMultiplier(multiplier, &quantized_multiplier, &shift);
430 return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
431 };
432
433 EXPECT_EQ(quant_and_multiply(0, 0.1), 0);
434 EXPECT_EQ(quant_and_multiply(1, 0), 0);
435 EXPECT_EQ(quant_and_multiply(10000, 0.00097656), 10);
436 EXPECT_EQ(quant_and_multiply(-10000, 0.00097656), -10);
437 EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::min(), 0.00001),
438 -21475);
439 EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::max(), 0.00001),
440 21475);
441 #if !TFLITE_SINGLE_ROUNDING
442 // Single-rounding doesn't support negative multipliers, only test negative
443 // multipliers in double-rounding mode.
444 EXPECT_EQ(quant_and_multiply(10000, -0.00097656), -10);
445 EXPECT_EQ(quant_and_multiply(-10000, -0.00097656), 10);
446 EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::min(), -0.00001),
447 21475);
448 EXPECT_EQ(quant_and_multiply(std::numeric_limits<int32_t>::max(), -0.00001),
449 -21475);
450 #endif
451
452 // Test with maximum possible x and quantized_multiplier
453 const int32_t x = std::numeric_limits<int32_t>::max();
454 const int32_t quantized_multiplier = std::numeric_limits<int32_t>::max();
455 const int shift = -3;
456 const int32_t expected = static_cast<int32_t>(
457 TfLiteRound(static_cast<int64_t>(x) * quantized_multiplier /
458 static_cast<double>(1LL << (31 - shift))));
459 EXPECT_EQ(MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift),
460 expected);
461 EXPECT_EQ(MultiplyByQuantizedMultiplier(-x, quantized_multiplier, shift),
462 -expected);
463 }
464
TEST(QuantizationUtilTest,MultiplyByQuantizedMultiplierInt64)465 TEST(QuantizationUtilTest, MultiplyByQuantizedMultiplierInt64) {
466 auto quant_and_multiply = [](int64_t x, double multiplier) {
467 int32_t quantized_multiplier;
468 int shift;
469 QuantizeMultiplier(multiplier, &quantized_multiplier, &shift);
470 return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
471 };
472
473 // Negative multipliers are not supported by the 64-bit
474 // MultiplyByQuantizedMultiplier, only use >= 0 multipliers.
475 EXPECT_EQ(quant_and_multiply(0, 0.1), 0);
476 EXPECT_EQ(quant_and_multiply(1, 0), 0);
477 EXPECT_EQ(quant_and_multiply(10000, 0.00097656), 10);
478 EXPECT_EQ(quant_and_multiply(-10000, 0.00097656), -10);
479 EXPECT_EQ(quant_and_multiply(-(1LL << 47), 0.00001), -1407385600);
480 EXPECT_EQ(quant_and_multiply((1LL << 47) - 1, 0.00001), 1407385600);
481
482 // Test with maximum possible x and quantized_multiplier
483 const int64_t x = (1LL << 47) - 1;
484 const int32_t quantized_multiplier = std::numeric_limits<int32_t>::max();
485 const int shift = -31;
486 // Expected is around 'x * quantized_multiplier / 2**(31 - shift)' ~= 65536
487 // As there is some rounding error, expected is a bit smaller.
488 const int32_t expected = 65534;
489 EXPECT_EQ(MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift),
490 expected);
491 EXPECT_EQ(MultiplyByQuantizedMultiplier(-x, quantized_multiplier, shift),
492 -expected);
493 }
494
TEST(QuantizationUtilTest,PreprocessSoftmaxScaling)495 TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
496 auto quantize = [](double beta, double scale, int integer_bits) {
497 int32_t q;
498 int s;
499 PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s);
500 return std::pair<int32_t, int>{q, s};
501 };
502
503 #if TFLITE_SINGLE_ROUNDING
504 // If beta * scale is greater than fits in the number of integer bits, the
505 // result is move near the maximum. Otherwise they quantize as expected.
506 // With 4 integer bits we can represent up to 8.0.
507 EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(2147483646, 30));
508 EXPECT_THAT(quantize(1.0, 4.0, 4), Pair(1073741824, 30));
509 // But with 5 bits we can go further.
510 EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(2147483646, 30));
511 EXPECT_THAT(quantize(2.0, 4.0, 5), Pair(1073741824, 30));
512 #else
513 // If beta * scale is greater than fits in the number of integer bits, the
514 // result is move near the maximum. Otherwise they quantize as expected.
515 // With 4 integer bits we can represent up to 16.0.
516 EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31));
517 EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31));
518 // But with 5 bits we can go further.
519 EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31));
520 EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31));
521 #endif
522 }
523 #endif // GTEST_HAS_DEATH_TEST
524
// Spot checks of CalculateInputRadius for a few fixed two-argument inputs.
TEST(QuantizationUtilTest, CalculateInputRadius) {
  const struct {
    int first_arg;
    int second_arg;
    int expected_radius;
  } kCases[] = {
      {4, 27, 15},
      {3, 27, 14},
      {3, 28, 7},
      {4, 2, 503316480},
  };
  for (const auto& c : kCases) {
    EXPECT_EQ(CalculateInputRadius(c.first_arg, c.second_arg),
              c.expected_radius)
        << "CalculateInputRadius(" << c.first_arg << ", " << c.second_arg
        << ")";
  }
}
531
// QuantizeMultiplierArray quantizes each scale independently. For the powers
// of two used here, the expected significand is +/-2^30 (zero for a zero
// scale), and the expected shift s satisfies |scale| == 0.5 * 2^s.
TEST(QuantizationUtilTest, QuantizeMultiplierArray) {
  struct Row {
    double scale;
    int32 significand;
    int shift;
  };
  const Row rows[] = {
      {-4.0, -1073741824, 3},   {-2.0, -1073741824, 2},
      {-1.0, -1073741824, 1},   {-0.5, -1073741824, 0},
      {-0.25, -1073741824, -1}, {-0.125, -1073741824, -2},
      {0.0, 0, 0},              {0.125, 1073741824, -2},
      {0.25, 1073741824, -1},   {0.5, 1073741824, 0},
      {1.0, 1073741824, 1},     {2.0, 1073741824, 2},
      {4.0, 1073741824, 3},
  };

  // Split the table into the input array and the two expected outputs.
  std::vector<double> weights;
  std::vector<int32> expected_significands;
  std::vector<int> expected_shifts;
  for (const Row& row : rows) {
    weights.push_back(row.scale);
    expected_significands.push_back(row.significand);
    expected_shifts.push_back(row.shift);
  }

  const int count = static_cast<int>(weights.size());
  std::vector<int32> significands(count);
  std::vector<int> shifts(count);
  QuantizeMultiplierArray(weights.data(), count, significands.data(),
                          shifts.data());

  EXPECT_THAT(significands, ElementsAreArray(expected_significands));
  EXPECT_THAT(shifts, ElementsAreArray(expected_shifts));
}
577
578 } // namespace
579 } // namespace tflite
580