// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/vunary.h>
#include <xnnpack/common.h>


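// 16-entry lookup table for the fractional part of the exponent; entry k is
// expected to hold 2**(-k/16) as a float (see the reconstruction of vs below).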
extern XNN_INTERNAL const float xnn_table_exp2minus_k_over_16[16];

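// ELU micro-kernel template: for each of the n/sizeof(float) input elements it
// computes y = beta * x for x >= 0 and y = alpha * (exp(prescale * x) - 1) for x < 0,
// approximating exp() with a two-step ("rr2") range reduction, the 16-entry table
// above, and a degree-3 ("p3") polynomial. A hypothetical instantiation with
// BATCH_TILE=8 and SSE=2 would generate, for example (argument names illustrative):
//   xnn_f32_velu_ukernel__sse2_rr2_lut16_p3_x8(batch_size * sizeof(float), input, output, &params);
// where params must have been filled in by the matching parameter-initialization routine.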
$ISA = {2: "sse2", 4: "sse41"}[SSE]
void xnn_f32_velu_ukernel__${ISA}_rr2_lut16_p3_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

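  // Load the algorithm constants; each field of the params struct is assumed to be
  // replicated across all four lanes so that a single _mm_load_ps suffices.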
  const __m128 vprescale = _mm_load_ps(params->sse2_rr2_lut16_p3.prescale);
  const __m128 valpha = _mm_load_ps(params->sse2_rr2_lut16_p3.alpha);
  const __m128 vbeta = _mm_load_ps(params->sse2_rr2_lut16_p3.beta);
  const __m128 vsat_cutoff = _mm_load_ps(params->sse2_rr2_lut16_p3.sat_cutoff);
  const __m128 vmagic_bias = _mm_load_ps(params->sse2_rr2_lut16_p3.magic_bias);
  const __m128 vlog2e = _mm_load_ps(params->sse2_rr2_lut16_p3.log2e);
  const __m128i vindex_mask = _mm_load_si128((const __m128i*) params->sse2_rr2_lut16_p3.index_mask);
  const __m128 vminus_ln2_hi = _mm_load_ps(params->sse2_rr2_lut16_p3.minus_ln2_hi);
  const __m128 vminus_ln2_lo = _mm_load_ps(params->sse2_rr2_lut16_p3.minus_ln2_lo);
  const __m128 vc3 = _mm_load_ps(params->sse2_rr2_lut16_p3.c3);
  const __m128 vc2 = _mm_load_ps(params->sse2_rr2_lut16_p3.c2);
  const __m128 vone = _mm_load_ps(params->sse2_rr2_lut16_p3.one);

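  // All of the loops below follow the same steps: clamp and prescale the input (z),
  // split exp(z) into 2**n * exp(t) with n quantized in steps of 1/16 via the
  // magic-bias rounding trick, reconstruct 2**n from the lookup table and the bits
  // of n, compute t = z - n*ln2 in two steps for extra precision, approximate
  // exp(t) with a degree-3 polynomial, and finally select alpha*(exp(z)-1) for
  // negative inputs or beta*x for non-negative ones.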
  $if BATCH_TILE > 4:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
      x += ${BATCH_TILE};

      $for N in range(0, BATCH_TILE, 4):
        const __m128 vz${ABC[N:N+4]} = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx${ABC[N:N+4]}, vprescale));

      $for N in range(0, BATCH_TILE, 4):
        __m128 vn${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vz${ABC[N:N+4]}, vlog2e), vmagic_bias);

      $for N in range(0, BATCH_TILE, 4):
        const __m128i vidx${ABC[N:N+4]} = _mm_slli_epi32(_mm_and_si128(_mm_castps_si128(vn${ABC[N:N+4]}), vindex_mask), 2);
        const __m128i ven${ABC[N:N+4]} = _mm_slli_epi32(_mm_castps_si128(vn${ABC[N:N+4]}), 19);

      #if XNN_ARCH_X86_64
        $for N in range(0, BATCH_TILE, 4):
          const uint64_t vidx${ABC[N:N+2]} = (uint64_t) _mm_cvtsi128_si64(vidx${ABC[N:N+4]});
          $if SSE >= 4:
            const uint64_t vidx${ABC[N+2:N+4]} = (uint64_t) _mm_extract_epi64(vidx${ABC[N:N+4]}, 1);
          $else:
            const uint64_t vidx${ABC[N+2:N+4]} = (uint64_t) _mm_cvtsi128_si64(_mm_unpackhi_epi64(vidx${ABC[N:N+4]}, vidx${ABC[N:N+4]}));
          const __m128i vl${ABC[N]}   = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx${ABC[N:N+2]})));
          const __m128i vl${ABC[N+2]} = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx${ABC[N+2:N+4]})));
          $if SSE >= 4:
            const __m128i vl${ABC[N:N+2]} = _mm_insert_epi32(vl${ABC[N]}, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx${ABC[N:N+2]} >> 32))), 1);
          $else:
            const __m128i vl${ABC[N+1]} = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx${ABC[N:N+2]} >> 32))));
            const __m128i vl${ABC[N:N+2]} = _mm_unpacklo_epi32(vl${ABC[N]}, vl${ABC[N+1]});
          $if SSE >= 4:
            const __m128i vl${ABC[N+2:N+4]} = _mm_insert_epi32(vl${ABC[N+2]}, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx${ABC[N+2:N+4]} >> 32))), 1);
          $else:
            const __m128i vl${ABC[N+3]} = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx${ABC[N+2:N+4]} >> 32))));
            const __m128i vl${ABC[N+2:N+4]} = _mm_unpacklo_epi32(vl${ABC[N+2]}, vl${ABC[N+3]});
          const __m128i vl${ABC[N:N+4]} = _mm_unpacklo_epi64(vl${ABC[N:N+2]}, vl${ABC[N+2:N+4]});
      #else  // !XNN_ARCH_X86_64
        $for N in range(0, BATCH_TILE, 4):
          const uint32_t vidx${ABC[N]} = (uint32_t) _mm_cvtsi128_si32(vidx${ABC[N:N+4]});
          const uint32_t vidx${ABC[N+1]} = (uint32_t) _mm_extract_epi16(vidx${ABC[N:N+4]}, 2);
          const uint32_t vidx${ABC[N+2]} = (uint32_t) _mm_extract_epi16(vidx${ABC[N:N+4]}, 4);
          const uint32_t vidx${ABC[N+3]} = (uint32_t) _mm_extract_epi16(vidx${ABC[N:N+4]}, 6);
          const __m128i vl${ABC[N]}   = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + vidx${ABC[N]})));
          const __m128i vl${ABC[N+2]} = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + vidx${ABC[N+2]})));
          $if SSE >= 4:
            const __m128i vl${ABC[N:N+2]} = _mm_insert_epi32(vl${ABC[N]}, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + vidx${ABC[N+1]})), 1);
          $else:
            const __m128i vl${ABC[N+1]} = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + vidx${ABC[N+1]})));
            const __m128i vl${ABC[N:N+2]} = _mm_unpacklo_epi32(vl${ABC[N]}, vl${ABC[N+1]});
          $if SSE >= 4:
            const __m128i vl${ABC[N+2:N+4]} = _mm_insert_epi32(vl${ABC[N+2]}, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + vidx${ABC[N+3]})), 1);
          $else:
            const __m128i vl${ABC[N+3]} = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + vidx${ABC[N+3]})));
            const __m128i vl${ABC[N+2:N+4]} = _mm_unpacklo_epi32(vl${ABC[N+2]}, vl${ABC[N+3]});
          const __m128i vl${ABC[N:N+4]} = _mm_unpacklo_epi64(vl${ABC[N:N+2]}, vl${ABC[N+2:N+4]});
      #endif  // XNN_ARCH_X86_64

      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = _mm_sub_ps(vn${ABC[N:N+4]}, vmagic_bias);
        __m128 vs${ABC[N:N+4]} = _mm_castsi128_ps(_mm_add_epi32(vl${ABC[N:N+4]}, ven${ABC[N:N+4]}));

      $for N in range(0, BATCH_TILE, 4):
        __m128 vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_hi), vz${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_lo), vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        __m128 vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vc3, vt${ABC[N:N+4]}), vc2);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_mul_ps(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});
        vs${ABC[N:N+4]} = _mm_sub_ps(vs${ABC[N:N+4]}, vone);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        const __m128 ve${ABC[N:N+4]} = _mm_mul_ps(_mm_add_ps(vp${ABC[N:N+4]}, vs${ABC[N:N+4]}), valpha);

      $for N in range(0, BATCH_TILE, 4):
        $if SSE < 4:
          const __m128 vm${ABC[N:N+4]} = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx${ABC[N:N+4]})));
        vx${ABC[N:N+4]} = _mm_mul_ps(vx${ABC[N:N+4]}, vbeta);

      $for N in range(0, BATCH_TILE, 4):
        $if SSE >= 4:
          const __m128 vy${ABC[N:N+4]} = _mm_blendv_ps(vx${ABC[N:N+4]}, ve${ABC[N:N+4]}, vx${ABC[N:N+4]});
        $else:
          const __m128 vy${ABC[N:N+4]} = _mm_or_ps(_mm_and_ps(ve${ABC[N:N+4]}, vm${ABC[N:N+4]}), _mm_andnot_ps(vm${ABC[N:N+4]}, vx${ABC[N:N+4]}));

      _mm_storeu_ps(y, vy${ABC[0:4]});
      $for N in range(4, BATCH_TILE, 4):
        _mm_storeu_ps(y + ${N}, vy${ABC[N:N+4]});
      y += ${BATCH_TILE};
    }
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    __m128 vx = _mm_loadu_ps(x);
    x += 4;

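    // Clamp and prescale: z = max(sat_cutoff, prescale * x). Inputs clamped here
    // are already so negative that exp(z) - 1 rounds to -1 in float precision.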
    const __m128 vz = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx, vprescale));

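    // n := z * log2(e), rounded to a multiple of 1/16 with the magic-bias trick;
    // the bias is subtracted again further down to recover n as a regular float.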
    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);

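    // Split the bits of vn into exponent-adjustment bits (ven) and a 4-bit table
    // index scaled to a byte offset (vidx).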
    const __m128i ven = _mm_slli_epi32(_mm_castps_si128(vn), 19);
    const __m128i vidx = _mm_slli_epi32(_mm_and_si128(_mm_castps_si128(vn), vindex_mask), 2);
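    // Gather the four table entries. The x86-64 path extracts the indices as two
    // 64-bit halves; the 32-bit path extracts them as 16-bit pieces.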
    #if XNN_ARCH_X86_64
      const uint64_t vidx_lo = (uint64_t) _mm_cvtsi128_si64(vidx);
      $if SSE >= 4:
        const uint64_t vidx_hi = (uint64_t) _mm_extract_epi64(vidx, 1);
      $else:
        const uint64_t vidx_hi = (uint64_t) _mm_cvtsi128_si64(_mm_unpackhi_epi64(vidx, vidx));
      const __m128i vl_ll   = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_lo)));
      const __m128i vl_hl = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_hi)));
      $if SSE >= 4:
        const __m128i vl_lo = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_lo >> 32))), 1);
      $else:
        const __m128i vl_lh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_lo >> 32))));
        const __m128i vl_lo = _mm_unpacklo_epi32(vl_ll, vl_lh);
      $if SSE >= 4:
        const __m128i vl_hi = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hi >> 32))), 1);
      $else:
        const __m128i vl_hh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hi >> 32))));
        const __m128i vl_hi = _mm_unpacklo_epi32(vl_hl, vl_hh);
    #else  // !XNN_ARCH_X86_64
      const __m128i vl_ll = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx))));
      const __m128i vl_hl = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 4))));
      $if SSE >= 4:
        const __m128i vl_lo = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 2))), 1);
      $else:
        const __m128i vl_lh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 2))));
        const __m128i vl_lo = _mm_unpacklo_epi32(vl_ll, vl_lh);
      $if SSE >= 4:
        const __m128i vl_hi = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 6))), 1);
      $else:
        const __m128i vl_hh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 6))));
        const __m128i vl_hi = _mm_unpacklo_epi32(vl_hl, vl_hh);
    #endif  // XNN_ARCH_X86_64
    const __m128i vl = _mm_unpacklo_epi64(vl_lo, vl_hi);
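    // Reconstruct vs ~= 2**n by combining the table entry with ven, then subtract
    // the magic bias to recover n itself.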
    __m128 vs = _mm_castsi128_ps(_mm_add_epi32(vl, ven));
    vn = _mm_sub_ps(vn, vmagic_bias);

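    // t := z - n * ln2, computed in two steps (hi + lo) for extra precision; this
    // is the two-step range reduction ("rr2") the kernel is named after.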
    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

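    // Degree-3 polynomial ("p3"): exp(t) ~= 1 + t + c2*t^2 + c3*t^3, folded in so
    // that ve = alpha * (vs * exp(t) - 1) ~= alpha * (exp(z) - 1).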
    __m128 vp = _mm_add_ps(_mm_mul_ps(vc3, vt), vc2);
    vp = _mm_mul_ps(vp, vt);

    vt = _mm_mul_ps(vt, vs);
    vs = _mm_sub_ps(vs, vone);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vt);
    const __m128 ve = _mm_mul_ps(_mm_add_ps(vp, vs), valpha);

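    // Select the result by the sign of x: alpha*(exp(z)-1) for negative inputs,
    // beta*x otherwise. SSE4.1 blends on the sign bit directly; SSE2 builds the mask.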
    $if SSE < 4:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
    vx = _mm_mul_ps(vx, vbeta);
    $if SSE >= 4:
      const __m128 vy = _mm_blendv_ps(vx, ve, vx);
    $else:
      const __m128 vy = _mm_or_ps(_mm_and_ps(ve, vm), _mm_andnot_ps(vm, vx));

    _mm_storeu_ps(y, vy);
    y += 4;
  }
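  // Process the final 1-3 elements: run the same computation on a full (possibly
  // over-reading) vector and store only the valid lanes.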
  if XNN_UNLIKELY(n != 0) {
    __m128 vx = _mm_loadu_ps(x);

    const __m128 vz = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx, vprescale));

    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);

    const __m128i ven = _mm_slli_epi32(_mm_castps_si128(vn), 19);
    const __m128i vidx = _mm_slli_epi32(_mm_and_si128(_mm_castps_si128(vn), vindex_mask), 2);
    #if XNN_ARCH_X86_64
      const uint64_t vidx_lo = (uint64_t) _mm_cvtsi128_si64(vidx);
      $if SSE >= 4:
        const uint64_t vidx_hi = (uint64_t) _mm_extract_epi64(vidx, 1);
      $else:
        const uint64_t vidx_hi = (uint64_t) _mm_cvtsi128_si64(_mm_unpackhi_epi64(vidx, vidx));
      const __m128i vl_ll   = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_lo)));
      const __m128i vl_hl = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) vidx_hi)));
      $if SSE >= 4:
        const __m128i vl_lo = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_lo >> 32))), 1);
      $else:
        const __m128i vl_lh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_lo >> 32))));
        const __m128i vl_lo = _mm_unpacklo_epi32(vl_ll, vl_lh);
      $if SSE >= 4:
        const __m128i vl_hi = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hi >> 32))), 1);
      $else:
        const __m128i vl_hh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) (vidx_hi >> 32))));
        const __m128i vl_hi = _mm_unpacklo_epi32(vl_hl, vl_hh);
    #else  // !XNN_ARCH_X86_64
      const __m128i vl_ll = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_cvtsi128_si32(vidx))));
      const __m128i vl_hl = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 4))));
      $if SSE >= 4:
        const __m128i vl_lo = _mm_insert_epi32(vl_ll, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 2))), 1);
      $else:
        const __m128i vl_lh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 2))));
        const __m128i vl_lo = _mm_unpacklo_epi32(vl_ll, vl_lh);
      $if SSE >= 4:
        const __m128i vl_hi = _mm_insert_epi32(vl_hl, *((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 6))), 1);
      $else:
        const __m128i vl_hh = _mm_cvtsi32_si128(*((const int*) ((uintptr_t) xnn_table_exp2minus_k_over_16 + (uint32_t) _mm_extract_epi16(vidx, 6))));
        const __m128i vl_hi = _mm_unpacklo_epi32(vl_hl, vl_hh);
    #endif  // XNN_ARCH_X86_64
    const __m128i vl = _mm_unpacklo_epi64(vl_lo, vl_hi);
    __m128 vs = _mm_castsi128_ps(_mm_add_epi32(vl, ven));
    vn = _mm_sub_ps(vn, vmagic_bias);

    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    __m128 vp = _mm_add_ps(_mm_mul_ps(vc3, vt), vc2);
    vp = _mm_mul_ps(vp, vt);

    vt = _mm_mul_ps(vt, vs);
    vs = _mm_sub_ps(vs, vone);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vt);
    const __m128 ve = _mm_mul_ps(_mm_add_ps(vp, vs), valpha);

    $if SSE < 4:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
    vx = _mm_mul_ps(vx, vbeta);
    $if SSE >= 4:
      __m128 vy = _mm_blendv_ps(vx, ve, vx);
    $else:
      __m128 vy = _mm_or_ps(_mm_and_ps(ve, vm), _mm_andnot_ps(vm, vx));

    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vy);
      vy = _mm_movehl_ps(vy, vy);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vy);
    }
  }
}