// Auto-generated file. Do not edit!
//   Template: src/f32-velu/avx-rr2-lut16-p3.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/vunary.h>


extern XNN_INTERNAL const int xnn_table_exp2minus_k_over_16[16];

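// Mask table used for the 1..7-element remainder: 7 all-ones entries followed by
// 7 zeroes, so that loading 8 consecutive entries starting at
// (&mask_table[7] - remaining bytes) yields a per-lane mask for _mm256_maskload_ps.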
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

void xnn_f32_velu_ukernel__avx_rr2_lut16_p3_x24(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n % sizeof(float) == 0);

  const __m256 vprescale = _mm256_broadcast_ps((const __m128*) params->sse.prescale);
  const __m256 valpha = _mm256_broadcast_ps((const __m128*) params->sse.alpha);
  const __m256 vbeta = _mm256_broadcast_ps((const __m128*) params->sse.beta);

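  // Constants of the rr2-lut16-p3 expm1 approximation:
  //  - vsat_cutoff clamps large negative z, where expm1(z) is already -1 to within
  //    single precision;
  //  - vmagic_bias rounds z * log2(e) to the nearest multiple of 1/16 and leaves the
  //    rounded value encoded in the low mantissa bits;
  //  - vminus_ln2_hi + vminus_ln2_lo together approximate -ln(2) for the two-step
  //    ("rr2") range reduction;
  //  - vc3 and vc2 are the coefficients of the degree-3 ("p3") polynomial that
  //    approximates exp(t) - 1 on the reduced range.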
  const __m256 vsat_cutoff = _mm256_set1_ps(-0x1.154246p+4f);
  const __m256 vmagic_bias = _mm256_set1_ps(0x1.800000p19f);
  const __m256 vlog2e = _mm256_set1_ps(0x1.715476p+0f);
  const __m256 vindex_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xF));
  const __m256 vminus_ln2_hi = _mm256_set1_ps(-0x1.62E400p-1f);
  const __m256 vminus_ln2_lo = _mm256_set1_ps(-0x1.7F7D1Cp-20f);
  const __m256 vc3 = _mm256_set1_ps(0x1.55561Cp-3f);
  const __m256 vc2 = _mm256_set1_ps(0x1.0001ECp-1f);
  const __m256 vone = _mm256_set1_ps(1.0f);

  for (; n >= 24 * sizeof(float); n -= 24 * sizeof(float)) {
    __m256 vx0 = _mm256_loadu_ps(x);
    __m256 vx1 = _mm256_loadu_ps(x + 8);
    __m256 vx2 = _mm256_loadu_ps(x + 16);
    x += 24;

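    // z = prescale * x, clamped at the saturation cutoff from below.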
    const __m256 vz0 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx0, vprescale));
    const __m256 vz1 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx1, vprescale));
    const __m256 vz2 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx2, vprescale));

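    // n = z * log2(e) rounded to the nearest multiple of 1/16, via the magic-bias trick.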
    __m256 vn0 = _mm256_add_ps(_mm256_mul_ps(vz0, vlog2e), vmagic_bias);
    __m256 vn1 = _mm256_add_ps(_mm256_mul_ps(vz1, vlog2e), vmagic_bias);
    __m256 vn2 = _mm256_add_ps(_mm256_mul_ps(vz2, vlog2e), vmagic_bias);

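    // The 4 low bits of each element of vn select one of the 16 table entries.
    // AVX has no vector gather, so the indices are scaled to byte offsets, moved to
    // general-purpose registers (64-bit extracts on x86-64, 32-bit extracts
    // otherwise), and the table entries are loaded and repacked one at a time.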
    const __m256 vidx0 = _mm256_and_ps(vn0, vindex_mask);

    const __m128i vidx0_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vidx0)), 2);
    const __m128i vidx0_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vidx0, 1)), 2);
    #if XNN_ARCH_X86_64
      const uint64_t vidx0_ll = (uint64_t) _mm_cvtsi128_si64(vidx0_lo);
      const uint64_t vidx0_lh = (uint64_t) _mm_extract_epi64(vidx0_lo, 1);
      const uint64_t vidx0_hl = (uint64_t) _mm_cvtsi128_si64(vidx0_hi);
      const uint64_t vidx0_hh = (uint64_t) _mm_extract_epi64(vidx0_hi, 1);
      __m128i vl0_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx0_ll));
      __m128i vl0_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx0_lh));
      __m128i vl0_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx0_hl));
      __m128i vl0_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx0_hh));
      vl0_ll = _mm_insert_epi32(vl0_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx0_ll >> 32))), 1);
      vl0_lh = _mm_insert_epi32(vl0_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx0_lh >> 32))), 1);
      vl0_hl = _mm_insert_epi32(vl0_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx0_hl >> 32))), 1);
      vl0_hh = _mm_insert_epi32(vl0_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx0_hh >> 32))), 1);
    #else
      __m128i vl0_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx0_lo)));
      __m128i vl0_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx0_lo, 2)));
      __m128i vl0_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx0_hi)));
      __m128i vl0_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx0_hi, 2)));
      vl0_ll = _mm_insert_epi32(vl0_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx0_lo, 1))), 1);
      vl0_lh = _mm_insert_epi32(vl0_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx0_lo, 3))), 1);
      vl0_hl = _mm_insert_epi32(vl0_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx0_hi, 1))), 1);
      vl0_hh = _mm_insert_epi32(vl0_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx0_hi, 3))), 1);
    #endif
    const __m128i vl0_lo = _mm_unpacklo_epi64(vl0_ll, vl0_lh);
    const __m128i vl0_hi = _mm_unpacklo_epi64(vl0_hl, vl0_hh);
    const __m256 vidx1 = _mm256_and_ps(vn1, vindex_mask);

    const __m128i vidx1_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vidx1)), 2);
    const __m128i vidx1_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vidx1, 1)), 2);
    #if XNN_ARCH_X86_64
      const uint64_t vidx1_ll = (uint64_t) _mm_cvtsi128_si64(vidx1_lo);
      const uint64_t vidx1_lh = (uint64_t) _mm_extract_epi64(vidx1_lo, 1);
      const uint64_t vidx1_hl = (uint64_t) _mm_cvtsi128_si64(vidx1_hi);
      const uint64_t vidx1_hh = (uint64_t) _mm_extract_epi64(vidx1_hi, 1);
      __m128i vl1_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx1_ll));
      __m128i vl1_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx1_lh));
      __m128i vl1_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx1_hl));
      __m128i vl1_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx1_hh));
      vl1_ll = _mm_insert_epi32(vl1_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx1_ll >> 32))), 1);
      vl1_lh = _mm_insert_epi32(vl1_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx1_lh >> 32))), 1);
      vl1_hl = _mm_insert_epi32(vl1_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx1_hl >> 32))), 1);
      vl1_hh = _mm_insert_epi32(vl1_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx1_hh >> 32))), 1);
    #else
      __m128i vl1_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx1_lo)));
      __m128i vl1_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx1_lo, 2)));
      __m128i vl1_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx1_hi)));
      __m128i vl1_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx1_hi, 2)));
      vl1_ll = _mm_insert_epi32(vl1_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx1_lo, 1))), 1);
      vl1_lh = _mm_insert_epi32(vl1_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx1_lo, 3))), 1);
      vl1_hl = _mm_insert_epi32(vl1_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx1_hi, 1))), 1);
      vl1_hh = _mm_insert_epi32(vl1_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx1_hi, 3))), 1);
    #endif
    const __m128i vl1_lo = _mm_unpacklo_epi64(vl1_ll, vl1_lh);
    const __m128i vl1_hi = _mm_unpacklo_epi64(vl1_hl, vl1_hh);
    const __m256 vidx2 = _mm256_and_ps(vn2, vindex_mask);

    const __m128i vidx2_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vidx2)), 2);
    const __m128i vidx2_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vidx2, 1)), 2);
    #if XNN_ARCH_X86_64
      const uint64_t vidx2_ll = (uint64_t) _mm_cvtsi128_si64(vidx2_lo);
      const uint64_t vidx2_lh = (uint64_t) _mm_extract_epi64(vidx2_lo, 1);
      const uint64_t vidx2_hl = (uint64_t) _mm_cvtsi128_si64(vidx2_hi);
      const uint64_t vidx2_hh = (uint64_t) _mm_extract_epi64(vidx2_hi, 1);
      __m128i vl2_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx2_ll));
      __m128i vl2_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx2_lh));
      __m128i vl2_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx2_hl));
      __m128i vl2_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx2_hh));
      vl2_ll = _mm_insert_epi32(vl2_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx2_ll >> 32))), 1);
      vl2_lh = _mm_insert_epi32(vl2_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx2_lh >> 32))), 1);
      vl2_hl = _mm_insert_epi32(vl2_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx2_hl >> 32))), 1);
      vl2_hh = _mm_insert_epi32(vl2_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx2_hh >> 32))), 1);
    #else
      __m128i vl2_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx2_lo)));
      __m128i vl2_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx2_lo, 2)));
      __m128i vl2_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx2_hi)));
      __m128i vl2_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx2_hi, 2)));
      vl2_ll = _mm_insert_epi32(vl2_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx2_lo, 1))), 1);
      vl2_lh = _mm_insert_epi32(vl2_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx2_lo, 3))), 1);
      vl2_hl = _mm_insert_epi32(vl2_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx2_hi, 1))), 1);
      vl2_hh = _mm_insert_epi32(vl2_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx2_hi, 3))), 1);
    #endif
    const __m128i vl2_lo = _mm_unpacklo_epi64(vl2_ll, vl2_lh);
    const __m128i vl2_hi = _mm_unpacklo_epi64(vl2_hl, vl2_hh);

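    // Reconstruct vs ~= 2**n: shifting the bits of vn left by 19 places the integer
    // part of n in the exponent field, and the integer addition with the table entry
    // accounts for the fractional part.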
    const __m128i ven0_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn0)), 19);
    const __m128i ven0_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn0, 1)), 19);
    vn0 = _mm256_sub_ps(vn0, vmagic_bias);
    const __m128 vs0_lo = _mm_castsi128_ps(_mm_add_epi32(vl0_lo, ven0_lo));
    const __m128 vs0_hi = _mm_castsi128_ps(_mm_add_epi32(vl0_hi, ven0_hi));
    const __m128i ven1_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn1)), 19);
    const __m128i ven1_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn1, 1)), 19);
    vn1 = _mm256_sub_ps(vn1, vmagic_bias);
    const __m128 vs1_lo = _mm_castsi128_ps(_mm_add_epi32(vl1_lo, ven1_lo));
    const __m128 vs1_hi = _mm_castsi128_ps(_mm_add_epi32(vl1_hi, ven1_hi));
    const __m128i ven2_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn2)), 19);
    const __m128i ven2_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn2, 1)), 19);
    vn2 = _mm256_sub_ps(vn2, vmagic_bias);
    const __m128 vs2_lo = _mm_castsi128_ps(_mm_add_epi32(vl2_lo, ven2_lo));
    const __m128 vs2_hi = _mm_castsi128_ps(_mm_add_epi32(vl2_hi, ven2_hi));

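    // t = z - n * ln(2), with ln(2) split into high and low parts to reduce the
    // rounding error of the two subtractions.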
    __m256 vt0 = _mm256_add_ps(_mm256_mul_ps(vn0, vminus_ln2_hi), vz0);
    __m256 vt1 = _mm256_add_ps(_mm256_mul_ps(vn1, vminus_ln2_hi), vz1);
    __m256 vt2 = _mm256_add_ps(_mm256_mul_ps(vn2, vminus_ln2_hi), vz2);

    vt0 = _mm256_add_ps(_mm256_mul_ps(vn0, vminus_ln2_lo), vt0);
    __m256 vs0 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs0_lo), vs0_hi, 1);
    vt1 = _mm256_add_ps(_mm256_mul_ps(vn1, vminus_ln2_lo), vt1);
    __m256 vs1 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs1_lo), vs1_hi, 1);
    vt2 = _mm256_add_ps(_mm256_mul_ps(vn2, vminus_ln2_lo), vt2);
    __m256 vs2 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs2_lo), vs2_hi, 1);

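    // Degree-3 polynomial: with s = 2**n, expm1(z) ~= s * (t + c2*t^2 + c3*t^3) + (s - 1).
    // The steps below build this expression without evaluating exp(t) explicitly.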
    __m256 vp0 = _mm256_add_ps(_mm256_mul_ps(vc3, vt0), vc2);
    __m256 vp1 = _mm256_add_ps(_mm256_mul_ps(vc3, vt1), vc2);
    __m256 vp2 = _mm256_add_ps(_mm256_mul_ps(vc3, vt2), vc2);

    vp0 = _mm256_mul_ps(vp0, vt0);
    vp1 = _mm256_mul_ps(vp1, vt1);
    vp2 = _mm256_mul_ps(vp2, vt2);

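    // vt becomes s*t and vs becomes s - 1, so the multiply-add below yields
    // s * (t + c2*t^2 + c3*t^3).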
    vt0 = _mm256_mul_ps(vt0, vs0);
    vs0 = _mm256_sub_ps(vs0, vone);
    vt1 = _mm256_mul_ps(vt1, vs1);
    vs1 = _mm256_sub_ps(vs1, vone);
    vt2 = _mm256_mul_ps(vt2, vs2);
    vs2 = _mm256_sub_ps(vs2, vone);

    vp0 = _mm256_add_ps(_mm256_mul_ps(vp0, vt0), vt0);
    vp1 = _mm256_add_ps(_mm256_mul_ps(vp1, vt1), vt1);
    vp2 = _mm256_add_ps(_mm256_mul_ps(vp2, vt2), vt2);

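    // ve = alpha * expm1(z) is the negative branch of ELU; the non-negative branch is
    // beta * x. The blend selects per lane on the sign bit of beta * x, which matches
    // the sign of x when beta is positive.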
    const __m256 ve0 = _mm256_mul_ps(_mm256_add_ps(vp0, vs0), valpha);
    vx0 = _mm256_mul_ps(vx0, vbeta);
    const __m256 ve1 = _mm256_mul_ps(_mm256_add_ps(vp1, vs1), valpha);
    vx1 = _mm256_mul_ps(vx1, vbeta);
    const __m256 ve2 = _mm256_mul_ps(_mm256_add_ps(vp2, vs2), valpha);
    vx2 = _mm256_mul_ps(vx2, vbeta);

    const __m256 vy0 = _mm256_blendv_ps(vx0, ve0, vx0);
    const __m256 vy1 = _mm256_blendv_ps(vx1, ve1, vx1);
    const __m256 vy2 = _mm256_blendv_ps(vx2, ve2, vx2);

    _mm256_storeu_ps(y, vy0);
    _mm256_storeu_ps(y + 8, vy1);
    _mm256_storeu_ps(y + 16, vy2);
    y += 24;
  }
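  // Process the remaining full vectors of 8 elements with the same algorithm.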
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    __m256 vx = _mm256_loadu_ps(x);
    x += 8;

    const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));

    __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    const __m256 vidx = _mm256_and_ps(vn, vindex_mask);

    const __m128i vidx_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vidx)), 2);
    const __m128i vidx_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vidx, 1)), 2);
    #if XNN_ARCH_X86_64
      const uint64_t vidx_ll = (uint64_t) _mm_cvtsi128_si64(vidx_lo);
      const uint64_t vidx_lh = (uint64_t) _mm_extract_epi64(vidx_lo, 1);
      const uint64_t vidx_hl = (uint64_t) _mm_cvtsi128_si64(vidx_hi);
      const uint64_t vidx_hh = (uint64_t) _mm_extract_epi64(vidx_hi, 1);
      __m128i vl_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_ll));
      __m128i vl_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_lh));
      __m128i vl_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_hl));
      __m128i vl_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_hh));
      vl_ll = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_ll >> 32))), 1);
      vl_lh = _mm_insert_epi32(vl_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_lh >> 32))), 1);
      vl_hl = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hl >> 32))), 1);
      vl_hh = _mm_insert_epi32(vl_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hh >> 32))), 1);
    #else
      __m128i vl_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx_lo)));
      __m128i vl_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_lo, 2)));
      __m128i vl_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx_hi)));
      __m128i vl_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_hi, 2)));
      vl_ll = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_lo, 1))), 1);
      vl_lh = _mm_insert_epi32(vl_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_lo, 3))), 1);
      vl_hl = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_hi, 1))), 1);
      vl_hh = _mm_insert_epi32(vl_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_hi, 3))), 1);
    #endif
    const __m128i ven_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 19);
    const __m128i ven_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn, 1)), 19);

    const __m128i vl_lo = _mm_unpacklo_epi64(vl_ll, vl_lh);
    const __m128i vl_hi = _mm_unpacklo_epi64(vl_hl, vl_hh);

    vn = _mm256_sub_ps(vn, vmagic_bias);
    const __m128 vs_lo = _mm_castsi128_ps(_mm_add_epi32(vl_lo, ven_lo));
    const __m128 vs_hi = _mm_castsi128_ps(_mm_add_epi32(vl_hi, ven_hi));

    __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_lo), vt);
    __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1);

    __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc3, vt), vc2);
    vp = _mm256_mul_ps(vp, vt);

    vt = _mm256_mul_ps(vt, vs);
    vs = _mm256_sub_ps(vs, vone);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vt);

    const __m256 ve = _mm256_mul_ps(_mm256_add_ps(vp, vs), valpha);
    vx = _mm256_mul_ps(vx, vbeta);
    const __m256 vy = _mm256_blendv_ps(vx, ve, vx);

    _mm256_storeu_ps(y, vy);
    y += 8;
  }
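  // Handle the final 1..7 elements: masked load, same computation, then store the
  // result 4/2/1 elements at a time.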
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n));

    __m256 vx = _mm256_maskload_ps(x, vmask);

    const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));

    __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    const __m256 vidx = _mm256_and_ps(vn, vindex_mask);

    const __m128i vidx_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vidx)), 2);
    const __m128i vidx_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vidx, 1)), 2);
    #if XNN_ARCH_X86_64
      const uint64_t vidx_ll = (uint64_t) _mm_cvtsi128_si64(vidx_lo);
      const uint64_t vidx_lh = (uint64_t) _mm_extract_epi64(vidx_lo, 1);
      const uint64_t vidx_hl = (uint64_t) _mm_cvtsi128_si64(vidx_hi);
      const uint64_t vidx_hh = (uint64_t) _mm_extract_epi64(vidx_hi, 1);
      __m128i vl_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_ll));
      __m128i vl_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_lh));
      __m128i vl_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_hl));
      __m128i vl_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_hh));
      vl_ll = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_ll >> 32))), 1);
      vl_lh = _mm_insert_epi32(vl_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_lh >> 32))), 1);
      vl_hl = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hl >> 32))), 1);
      vl_hh = _mm_insert_epi32(vl_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hh >> 32))), 1);
    #else
      __m128i vl_ll = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx_lo)));
      __m128i vl_lh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_lo, 2)));
      __m128i vl_hl = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx_hi)));
      __m128i vl_hh = _mm_loadu_si32((const void*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_hi, 2)));
      vl_ll = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_lo, 1))), 1);
      vl_lh = _mm_insert_epi32(vl_lh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_lo, 3))), 1);
      vl_hl = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_hi, 1))), 1);
      vl_hh = _mm_insert_epi32(vl_hh, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi32(vidx_hi, 3))), 1);
    #endif
    const __m128i ven_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 19);
    const __m128i ven_hi = _mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn, 1)), 19);

    const __m128i vl_lo = _mm_unpacklo_epi64(vl_ll, vl_lh);
    const __m128i vl_hi = _mm_unpacklo_epi64(vl_hl, vl_hh);

    vn = _mm256_sub_ps(vn, vmagic_bias);
    const __m128 vs_lo = _mm_castsi128_ps(_mm_add_epi32(vl_lo, ven_lo));
    const __m128 vs_hi = _mm_castsi128_ps(_mm_add_epi32(vl_hi, ven_hi));

    __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_lo), vt);
    __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1);

    __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc3, vt), vc2);
    vp = _mm256_mul_ps(vp, vt);

    vt = _mm256_mul_ps(vt, vs);
    vs = _mm256_sub_ps(vs, vone);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vt);

    const __m256 ve = _mm256_mul_ps(_mm256_add_ps(vp, vs), valpha);
    vx = _mm256_mul_ps(vx, vbeta);
    const __m256 vy = _mm256_blendv_ps(vx, ve, vx);

    // _mm256_maskstore_ps(y, vmask, vy) could be used here, but triggers msan failures (probably an msan bug).
    __m128 vy_lo = _mm256_castps256_ps128(vy);
    if (n & (4 * sizeof(float))) {
      _mm_storeu_ps(y, vy_lo);
      vy_lo = _mm256_extractf128_ps(vy, 1);
      y += 4;
    }
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vy_lo);
      vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vy_lo);
    }
  }
}