/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_AVX2_QUANTIZATION_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_AVX2_QUANTIZATION_UTILS_H_
#ifdef __AVX2__

#include <immintrin.h>

#include <cstdint>
#include <limits>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {
namespace avx2_utils {

static inline void mm_storeu_si64(void *dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si64(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si64.
  *static_cast<std::int64_t *>(dst) = _mm_extract_epi64(v, 0);
#endif
}

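// Per-lane select between `a` and `b`: lanes whose mask element has the sign
// bit set take the value from `b`, all other lanes take the value from `a`.
// The integer vectors are reinterpreted as floats so that _mm256_blendv_ps can
// be used; this is safe here because the masks produced by _mm256_cmpgt_epi32
// are all-ones or all-zeros per lane.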
static inline __m256i mm256_blendv_epi32(const __m256i &a, const __m256i &b,
                                         const __m256i &mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

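// Rounding arithmetic right shift of each int32 lane:
//   (value + (1 << (right_shift - 1))) >> right_shift,
// saturating to std::numeric_limits<int32_t>::max() in lanes where adding the
// rounding nudge would overflow. Requires right_shift > 0.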
static inline __m256i rounding_right_shift(const __m256i &value,
                                           int32_t right_shift) {
  TFLITE_DCHECK_GT(right_shift, 0);
  const int32_t one_shift_exp_minus1 = 1 << (right_shift - 1);
  __m256i nudge = _mm256_set1_epi32(one_shift_exp_minus1);
  const __m256i r_plus_nudge = _mm256_add_epi32(value, nudge);
  const __m256i shifted_sum =
      _mm256_srav_epi32(r_plus_nudge, _mm256_set1_epi32(right_shift));

  // Identify overflow in each lane and create mask.
  const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
      value, _mm256_set1_epi32(0x7fffffff - one_shift_exp_minus1));
  // Fill results with either (value + nudge) >> right_shift or
  // std::numeric_limits<std::int32_t>::max() in the case of overflow.
  return mm256_blendv_epi32(
      shifted_sum, _mm256_set1_epi32(std::numeric_limits<std::int32_t>::max()),
      mask_num_plus_nudge_overflow);
}

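// Same as above, but with a per-lane shift amount. Shift amounts are expected
// to be non-negative; lanes with a zero shift get no rounding nudge and pass
// through unchanged.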
static inline __m256i rounding_right_shift(const __m256i &value,
                                           const __m256i right_shift) {
  const __m256i zeros = _mm256_setzero_si256();
  const __m256i mask_rightshift_gtz = _mm256_cmpgt_epi32(right_shift, zeros);
  const __m256i one_shift_exp_minus1 =
      _mm256_sllv_epi32(_mm256_set1_epi32(1),
                        _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
  __m256i nudge =
      mm256_blendv_epi32(zeros, one_shift_exp_minus1, mask_rightshift_gtz);
  const __m256i r_plus_nudge = _mm256_add_epi32(value, nudge);
  const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, right_shift);

  // Identify overflow in each lane and create mask.
  const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
      value, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
  // Fill results with either (value + nudge) >> right_shift or
  // std::numeric_limits<std::int32_t>::max() in the case of overflow.
  return mm256_blendv_epi32(
      shifted_sum, _mm256_set1_epi32(std::numeric_limits<std::int32_t>::max()),
      mask_num_plus_nudge_overflow);
}

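// Narrows the eight int32 lanes of `v` to int16 by truncation (no saturation)
// and stores the eight results contiguously at `dst`.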
inline void CastInt32ToInt16AndStore(int16 *dst, const __m256i v) {
  // As _mm256_cvtepi32_epi16 is not supported in AVX2, use the below repack.
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit value to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
  mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0));
  mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
}

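// Fixed-point requantization of each int32 lane: the lane is optionally
// pre-shifted left, multiplied by the Q31 `multiplier` keeping bits [31..62]
// of the 64-bit product (a truncating doubling high multiply), and, for a
// negative `left_shift`, finally passed through rounding_right_shift.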
inline __m256i MultiplyByQuantizedMultiplier(const __m256i &value,
                                             const int32_t multiplier,
                                             const int32_t left_shift) {
  const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  const __m256i shifted_value =
      left_shift > 0 ? _mm256_sllv_epi32(value, _mm256_set1_epi32(left_shift))
                     : value;

  __m256i scaled_v_low = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 0)),
      _mm256_set1_epi64x(multiplier));
  __m256i scaled_v_high = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 1)),
      _mm256_set1_epi64x(multiplier));

  scaled_v_low = _mm256_srlv_epi64(scaled_v_low, _mm256_set1_epi64x(31));
  scaled_v_high = _mm256_srlv_epi64(scaled_v_high, _mm256_set1_epi64x(31));
  // As _mm256_cvtepi64_epi32 is not supported in AVX2, use the below permute.
  scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
  __m256i result = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
  result = _mm256_permutevar8x32_epi32(result, repack_perm);
  if (left_shift >= 0) {
    return result;
  }
  return rounding_right_shift(result, -left_shift);
}

inline __m256i MultiplyByQuantizedMultiplier(const __m256i &value,
                                             const __m256i multiplier,
                                             const __m256i left_shift) {
  const __m256i zero_vector = _mm256_setzero_si256();
  const __m256i positive_left_shift = _mm256_max_epi32(left_shift, zero_vector);
  const __m256i positive_right_shift =
      _mm256_max_epi32(_mm256_sub_epi32(zero_vector, left_shift), zero_vector);

  const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  const __m256i shifted_value = _mm256_sllv_epi32(value, positive_left_shift);

  const __m256i multiplier_low =
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(multiplier, 0));
  const __m256i multiplier_high =
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(multiplier, 1));

  __m256i scaled_v_low = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 0)),
      multiplier_low);
  __m256i scaled_v_high = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 1)),
      multiplier_high);

  scaled_v_low = _mm256_srlv_epi64(scaled_v_low, _mm256_set1_epi64x(31));
  scaled_v_high = _mm256_srlv_epi64(scaled_v_high, _mm256_set1_epi64x(31));
  // As _mm256_cvtepi64_epi32 is not supported in AVX2, use the below permute.
  scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
  __m256i result = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
  result = _mm256_permutevar8x32_epi32(result, repack_perm);

  return rounding_right_shift(result, positive_right_shift);
}
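
// A minimal usage sketch (illustrative only): requantizing a buffer of int32
// accumulators to int16 with a single scalar multiplier/shift pair. The names
// RequantizeToInt16, acc, output, size, multiplier and shift are placeholders
// for this sketch, not symbols defined elsewhere in TFLite; a real kernel
// would also need a scalar tail loop for the last `size % 8` elements.
//
//   inline void RequantizeToInt16(const int32_t *acc, int32_t size,
//                                 int32_t multiplier, int32_t shift,
//                                 int16_t *output) {
//     const __m256i min_val = _mm256_set1_epi32(-32768);
//     const __m256i max_val = _mm256_set1_epi32(32767);
//     for (int32_t i = 0; i + 8 <= size; i += 8) {
//       const __m256i v =
//           _mm256_loadu_si256(reinterpret_cast<const __m256i *>(acc + i));
//       __m256i scaled = MultiplyByQuantizedMultiplier(v, multiplier, shift);
//       // Clamp to the int16 range before the truncating narrow-and-store.
//       scaled = _mm256_max_epi32(scaled, min_val);
//       scaled = _mm256_min_epi32(scaled, max_val);
//       CastInt32ToInt16AndStore(output + i, scaled);
//     }
//   }
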
}  // namespace avx2_utils
}  // namespace tflite

#endif  // __AVX2__
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_AVX2_QUANTIZATION_UTILS_H_