/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_AVX2_QUANTIZATION_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_AVX2_QUANTIZATION_UTILS_H_
#ifdef __AVX2__

#include <immintrin.h>

#include <cstdint>
#include <limits>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {
namespace avx2_utils {

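// Stores the low 64 bits of `v` to the (possibly unaligned) address `dst`.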
static inline void mm_storeu_si64(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si64(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si64.
  *static_cast<std::int64_t*>(dst) = _mm_extract_epi64(v, 0);
#endif
}

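// Emulates a per-lane 32-bit variable blend: for each lane, picks `b` when
// the most significant bit of the corresponding `mask` lane is set and `a`
// otherwise, using _mm256_blendv_ps on bitcast values.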
static inline __m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                                         const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

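// Rounding arithmetic right shift: computes
//   (value + (1 << (right_shift - 1))) >> right_shift
// per lane, saturating to std::numeric_limits<std::int32_t>::max() in lanes
// where adding the rounding nudge would overflow. Requires right_shift > 0.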
static inline __m256i rounding_right_shift(const __m256i& value,
                                           int32_t right_shift) {
  TFLITE_DCHECK_GT(right_shift, 0);
  const int32_t one_shift_exp_minus1 = 1 << (right_shift - 1);
  __m256i nudge = _mm256_set1_epi32(one_shift_exp_minus1);
  const __m256i r_plus_nudge = _mm256_add_epi32(value, nudge);
  const __m256i shifted_sum =
      _mm256_srav_epi32(r_plus_nudge, _mm256_set1_epi32(right_shift));

  // Identify overflow in each lane and create mask.
  const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
      value, _mm256_set1_epi32(0x7fffffff - one_shift_exp_minus1));
  // Fill results with either (value + nudge) >> right_shift or
  // std::numeric_limits<std::int32_t>::max() in the case of overflow.
  return mm256_blendv_epi32(
      shifted_sum, _mm256_set1_epi32(std::numeric_limits<std::int32_t>::max()),
      mask_num_plus_nudge_overflow);
}

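// Per-lane variant of the rounding right shift above: each lane of
// `right_shift` supplies its own shift amount, and lanes whose shift is not
// positive receive no rounding nudge.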
static inline __m256i rounding_right_shift(const __m256i& value,
                                           const __m256i right_shift) {
  const __m256i zeros = _mm256_setzero_si256();
  const __m256i mask_rightshift_gtz = _mm256_cmpgt_epi32(right_shift, zeros);
  const __m256i one_shift_exp_minus1 =
      _mm256_sllv_epi32(_mm256_set1_epi32(1),
                        _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
  __m256i nudge =
      mm256_blendv_epi32(zeros, one_shift_exp_minus1, mask_rightshift_gtz);
  const __m256i r_plus_nudge = _mm256_add_epi32(value, nudge);
  const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, right_shift);

  // Identify overflow in each lane and create mask.
  const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
      value, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
  // Fill results with either (value + nudge) >> right_shift or
  // std::numeric_limits<std::int32_t>::max() in the case of overflow.
  return mm256_blendv_epi32(
      shifted_sum, _mm256_set1_epi32(std::numeric_limits<std::int32_t>::max()),
      mask_num_plus_nudge_overflow);
}

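// Narrows the eight signed 32-bit lanes of `v` to 16 bits by keeping the low
// halves, and stores the resulting eight int16_t values to `dst`.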
inline void CastInt32ToInt16AndStore(int16_t* dst, const __m256i v) {
  // As _mm256_cvtepi32_epi16 is not supported in AVX2, use the below repack.
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each 128-bit lane,
  // effectively truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
  mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0));
  mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
}

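// Multiplies each 32-bit lane of `value` by the quantized `multiplier`,
// keeping bits 31..62 of each 64-bit product (i.e. (value * multiplier) >> 31
// truncated to 32 bits). A positive `left_shift` is applied to `value` before
// the multiplication; a negative `left_shift` is applied afterwards as a
// rounding, saturating right shift by -left_shift.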
inline __m256i MultiplyByQuantizedMultiplier(const __m256i& value,
                                             const int32_t multiplier,
                                             const int32_t left_shift) {
  const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  const __m256i shifted_value =
      left_shift > 0 ? _mm256_sllv_epi32(value, _mm256_set1_epi32(left_shift))
                     : value;

  __m256i scaled_v_low = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 0)),
      _mm256_set1_epi64x(multiplier));
  __m256i scaled_v_high = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 1)),
      _mm256_set1_epi64x(multiplier));

  scaled_v_low = _mm256_srlv_epi64(scaled_v_low, _mm256_set1_epi64x(31));
  scaled_v_high = _mm256_srlv_epi64(scaled_v_high, _mm256_set1_epi64x(31));
  // As _mm256_cvtepi64_epi32 is not supported in AVX2, use the below permute.
  scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
  __m256i result = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
  result = _mm256_permutevar8x32_epi32(result, repack_perm);
  if (left_shift >= 0) {
    return result;
  }
  return rounding_right_shift(result, -left_shift);
}

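// Variant of MultiplyByQuantizedMultiplier above with a per-lane multiplier
// and per-lane shift (e.g. for per-channel quantization). Lanes with a
// positive shift are shifted left before the multiplication; lanes with a
// negative shift get a rounding right shift afterwards.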
inline __m256i MultiplyByQuantizedMultiplier(const __m256i& value,
                                             const __m256i multiplier,
                                             const __m256i left_shift) {
  const __m256i zero_vector = _mm256_setzero_si256();
  const __m256i positive_left_shift = _mm256_max_epi32(left_shift, zero_vector);
  const __m256i positive_right_shift =
      _mm256_max_epi32(_mm256_sub_epi32(zero_vector, left_shift), zero_vector);

  const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  const __m256i shifted_value = _mm256_sllv_epi32(value, positive_left_shift);

  const __m256i multiplier_low =
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(multiplier, 0));
  const __m256i multiplier_high =
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(multiplier, 1));

  __m256i scaled_v_low = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 0)),
      multiplier_low);
  __m256i scaled_v_high = _mm256_mul_epi32(
      _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_value, 1)),
      multiplier_high);

  scaled_v_low = _mm256_srlv_epi64(scaled_v_low, _mm256_set1_epi64x(31));
  scaled_v_high = _mm256_srlv_epi64(scaled_v_high, _mm256_set1_epi64x(31));
  // As _mm256_cvtepi64_epi32 is not supported in AVX2, use the below permute.
  scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
  __m256i result = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
  result = _mm256_permutevar8x32_epi32(result, repack_perm);

  return rounding_right_shift(result, positive_right_shift);
}
}  // namespace avx2_utils
}  // namespace tflite

#endif  // __AVX2__
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_AVX2_QUANTIZATION_UTILS_H_