• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/optimized/avx2_quantization_utils.h"
16 
17 #include <gmock/gmock.h>
18 #include "tensorflow/lite/kernels/internal/common.h"
19 
20 #ifdef __AVX2__
21 namespace tflite {
22 namespace avx2_utils {
23 namespace {
24 
25 using ::testing::ElementsAreArray;
26 
FillVectorWithInt32(const std::vector<int32_t> & src)27 __m256i FillVectorWithInt32(const std::vector<int32_t>& src) {
28   return _mm256_set_epi32(src[7], src[6], src[5], src[4], src[3], src[2],
29                           src[1], src[0]);
30 }
31 
CompareWithReferenceValue(std::vector<int32_t> & reference_values,const __m256i & result)32 void CompareWithReferenceValue(std::vector<int32_t>& reference_values,
33                                const __m256i& result) {
34   // As _mm256_extract_epi32 only supports const int, which should be known
35   // at the comile time, it puts down 8 comparison instead of for-loop.
36   EXPECT_NEAR(reference_values[0], _mm256_extract_epi32(result, 0), 1);
37   EXPECT_NEAR(reference_values[1], _mm256_extract_epi32(result, 1), 1);
38   EXPECT_NEAR(reference_values[2], _mm256_extract_epi32(result, 2), 1);
39   EXPECT_NEAR(reference_values[3], _mm256_extract_epi32(result, 3), 1);
40   EXPECT_NEAR(reference_values[4], _mm256_extract_epi32(result, 4), 1);
41   EXPECT_NEAR(reference_values[5], _mm256_extract_epi32(result, 5), 1);
42   EXPECT_NEAR(reference_values[6], _mm256_extract_epi32(result, 6), 1);
43   EXPECT_NEAR(reference_values[7], _mm256_extract_epi32(result, 7), 1);
44 }
45 
// Checks that CastInt32ToInt16AndStore writes the eight 32-bit lanes of a
// vector, in lane order, into an int16_t buffer. Every input already fits in
// int16_t (including negatives and both extremes), so the stored result must
// equal the input however out-of-range values would be handled.
TEST(CastInt32ToInt16AndStoreTest, CastInt32ToInt16AndStoreTest) {
  const std::vector<int16_t> src = {1, -2, 300, -400, 5000, -1, 32767, -32768};
  // Zero-filled; 0 does not appear in `src`, so a lane the function fails to
  // write is reported as a mismatch rather than reading uninitialized memory.
  int16_t dst[8] = {};
  const __m256i src_vector = _mm256_set_epi32(src[7], src[6], src[5], src[4],
                                              src[3], src[2], src[1], src[0]);
  CastInt32ToInt16AndStore(dst, src_vector);
  EXPECT_THAT(src, ElementsAreArray(dst));
}
54 
// Compares the AVX2 MultiplyByQuantizedMultiplier, with a single multiplier
// and a single positive shift applied to every lane, against the per-element
// tflite::MultiplyByQuantizedMultiplier reference.
TEST(MultiplyByQuantizedMultiplierTest, PositiveLeftShiftTest) {
  const std::vector<int32_t> values = {100, 200, 300, 400, 500, 600, 700, 800};
  const __m256i src_vector = FillVectorWithInt32(values);
  const int32_t left_shift = 20;
  const int32_t multiplier = 12345;
  const __m256i result =
      MultiplyByQuantizedMultiplier(src_vector, multiplier, left_shift);

  // Expected lanes, computed one element at a time by the reference
  // implementation. A range-for avoids the int vs. size_t comparison.
  std::vector<int32_t> expected;
  expected.reserve(values.size());
  for (const int32_t value : values) {
    expected.push_back(
        tflite::MultiplyByQuantizedMultiplier(value, multiplier, left_shift));
  }

  CompareWithReferenceValue(expected, result);
}
71 
// Compares the AVX2 MultiplyByQuantizedMultiplier, with a single multiplier
// and a single negative shift applied to every lane, against the per-element
// tflite::MultiplyByQuantizedMultiplier reference.
TEST(MultiplyByQuantizedMultiplierTest, NegativeLeftShiftTest) {
  const std::vector<int32_t> values = {1000, 2000, 3000, 4000,
                                       5000, 6000, 7000, 8000};
  const __m256i src_vector = FillVectorWithInt32(values);
  const int32_t left_shift = -3;
  const int32_t multiplier = 1234567890;
  const __m256i result =
      MultiplyByQuantizedMultiplier(src_vector, multiplier, left_shift);

  // Expected lanes, computed one element at a time by the reference
  // implementation. A range-for avoids the int vs. size_t comparison.
  std::vector<int32_t> expected;
  expected.reserve(values.size());
  for (const int32_t value : values) {
    expected.push_back(
        tflite::MultiplyByQuantizedMultiplier(value, multiplier, left_shift));
  }

  CompareWithReferenceValue(expected, result);
}
89 
// Compares the AVX2 MultiplyByQuantizedMultiplier overload that takes a
// per-lane multiplier vector and a per-lane (positive) shift vector against
// the per-element tflite::MultiplyByQuantizedMultiplier reference.
TEST(MultiplyByQuantizedMultiplierTest, VectorPositiveLeftShiftTest) {
  const std::vector<int32_t> values = {100, 200, 300, 400, 500, 600, 700, 800};
  const std::vector<int32_t> left_shifts = {20, 19, 18, 17, 16, 15, 14, 13};
  const std::vector<int32_t> multipliers = {10000, 20000, 30000, 40000,
                                            50000, 60000, 70000, 80000};
  const __m256i src_vector = FillVectorWithInt32(values);
  const __m256i left_shifts_vector = FillVectorWithInt32(left_shifts);
  const __m256i multipliers_vector = FillVectorWithInt32(multipliers);

  const __m256i result = MultiplyByQuantizedMultiplier(
      src_vector, multipliers_vector, left_shifts_vector);

  // Expected lanes, computed one element at a time by the reference
  // implementation with that lane's own multiplier and shift. The index is
  // size_t so it compares cleanly against values.size().
  std::vector<int32_t> expected;
  expected.reserve(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    expected.push_back(tflite::MultiplyByQuantizedMultiplier(
        values[i], multipliers[i], left_shifts[i]));
  }

  CompareWithReferenceValue(expected, result);
}
110 
// Compares the AVX2 MultiplyByQuantizedMultiplier overload that takes a
// per-lane multiplier vector and a per-lane (negative) shift vector against
// the per-element tflite::MultiplyByQuantizedMultiplier reference.
TEST(MultiplyByQuantizedMultiplierTest, VectorNegativeLeftShiftTest) {
  const std::vector<int32_t> values = {1000, 2000, 3000, 4000,
                                       5000, 6000, 7000, 8000};
  const std::vector<int32_t> left_shifts = {-3, -4, -5, -6, -7, -8, -9, -10};
  const std::vector<int32_t> multipliers = {1000000000, 1100000000, 1200000000,
                                            1300000000, 1400000000, 1500000000,
                                            1600000000, 1700000000};
  const __m256i src_vector = FillVectorWithInt32(values);
  const __m256i left_shifts_vector = FillVectorWithInt32(left_shifts);
  const __m256i multipliers_vector = FillVectorWithInt32(multipliers);

  const __m256i result = MultiplyByQuantizedMultiplier(
      src_vector, multipliers_vector, left_shifts_vector);

  // Expected lanes, computed one element at a time by the reference
  // implementation with that lane's own multiplier and shift. The index is
  // size_t so it compares cleanly against values.size().
  std::vector<int32_t> expected;
  expected.reserve(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    expected.push_back(tflite::MultiplyByQuantizedMultiplier(
        values[i], multipliers[i], left_shifts[i]));
  }

  CompareWithReferenceValue(expected, result);
}
133 
134 }  // namespace
135 }  // namespace avx2_utils
136 }  // namespace tflite
137 
138 #endif  //  __AVX2__
139