// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/vunary.h>
#include <xnnpack/common.h>


$ISA = {2: "sse2", 4: "sse41"}[SSE]
void xnn_f32_velu_ukernel__${ISA}_rr2_p6_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

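  // Per-op parameters: prescale is applied to the input before exponentiation,
  // alpha scales the exponential (negative) branch, and beta scales the linear (positive) branch.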
  const __m128 vprescale = _mm_load_ps(params->sse.prescale);
  const __m128 valpha = _mm_load_ps(params->sse.alpha);
  const __m128 vbeta = _mm_load_ps(params->sse.beta);

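  // Constants for computing exp(z) - 1: saturation cutoff, magic rounding bias, log2(e),
  // -ln(2) split into high and low parts (the two-step range reduction behind "rr2"),
  // and the degree-6 polynomial coefficients c2..c6 ("p6").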
  const __m128 vsat_cutoff = _mm_set1_ps(-0x1.154246p+4f);
  const __m128 vmagic_bias = _mm_set1_ps(0x1.8000FEp23f);
  const __m128 vlog2e = _mm_set1_ps(0x1.715476p+0f);
  const __m128 vminus_ln2_hi = _mm_set1_ps(-0x1.62E440p-1f);
  const __m128 vminus_ln2_lo = _mm_set1_ps(0x1.0105C6p-21f);
  const __m128 vc6 = _mm_set1_ps(0x1.6b7338p-10f);
  const __m128 vc5 = _mm_set1_ps(0x1.12278Ep-7f);
  const __m128 vc4 = _mm_set1_ps(0x1.555716p-5f);
  const __m128 vc3 = _mm_set1_ps(0x1.5554B0p-3f);
  const __m128 vc2 = _mm_set1_ps(0x1.FFFFFEp-2f);
  const __m128 vone = _mm_set1_ps(1.0f);

  $if BATCH_TILE > 4:
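    // Main loop: process ${BATCH_TILE} elements per iteration.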
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
      x += ${BATCH_TILE};

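      // z := prescale * x, clamped from below at the saturation cutoff, where the
      // exponential branch becomes indistinguishable from -alpha in single precision.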
      $for N in range(0, BATCH_TILE, 4):
        const __m128 vz${ABC[N:N+4]} = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx${ABC[N:N+4]}, vprescale));

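      // n := round(z / ln(2)), computed by multiplying by log2(e) and adding a large
      // magic bias that leaves the rounded integer in the low mantissa bits.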
      $for N in range(0, BATCH_TILE, 4):
        __m128 vn${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vz${ABC[N:N+4]}, vlog2e), vmagic_bias);

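      // s := 2**n, formed by shifting the biased integer bits of n into the exponent field.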
      $for N in range(0, BATCH_TILE, 4):
        __m128 vs${ABC[N:N+4]} = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn${ABC[N:N+4]}), 23));

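      // Subtract the magic bias to recover n as a float.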
      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = _mm_sub_ps(vn${ABC[N:N+4]}, vmagic_bias);

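      // t := z - n * ln(2), using the high and low parts of ln(2) for extra precision.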
      $for N in range(0, BATCH_TILE, 4):
        __m128 vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_hi), vz${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_lo), vt${ABC[N:N+4]});

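      // Degree-6 polynomial in Horner form:
      //   p := t * (c2 + t * (c3 + t * (c4 + t * (c5 + t * c6)))), which approximates (exp(t) - 1 - t) / t.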
      $for N in range(0, BATCH_TILE, 4):
        __m128 vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vc6, vt${ABC[N:N+4]}), vc5);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc4);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc3);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc2);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

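      // Reassemble expm1(z) = s * (exp(t) - 1) + (s - 1): fold s into t and compute s - 1.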
      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_mul_ps(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});
        vs${ABC[N:N+4]} = _mm_sub_ps(vs${ABC[N:N+4]}, vone);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vt${ABC[N:N+4]});

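      // Exponential branch of ELU: e := alpha * (exp(z) - 1).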
      $for N in range(0, BATCH_TILE, 4):
        const __m128 ve${ABC[N:N+4]} = _mm_mul_ps(_mm_add_ps(vp${ABC[N:N+4]}, vs${ABC[N:N+4]}), valpha);

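      // Linear branch: beta * x. On SSE2, also build a mask of lanes whose sign bit is set.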
      $for N in range(0, BATCH_TILE, 4):
        $if SSE < 4:
          const __m128 vm${ABC[N:N+4]} = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx${ABC[N:N+4]})));
        vx${ABC[N:N+4]} = _mm_mul_ps(vx${ABC[N:N+4]}, vbeta);

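      // Select e where x is negative and beta * x otherwise: SSE4.1 blends on the sign bit,
      // SSE2 combines the precomputed mask with AND/ANDNOT/OR.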
      $for N in range(0, BATCH_TILE, 4):
        $if SSE >= 4:
          const __m128 vy${ABC[N:N+4]} = _mm_blendv_ps(vx${ABC[N:N+4]}, ve${ABC[N:N+4]}, vx${ABC[N:N+4]});
        $else:
          const __m128 vy${ABC[N:N+4]} = _mm_or_ps(_mm_and_ps(ve${ABC[N:N+4]}, vm${ABC[N:N+4]}), _mm_andnot_ps(vm${ABC[N:N+4]}, vx${ABC[N:N+4]}));

      _mm_storeu_ps(y, vy${ABC[0:4]});
      $for N in range(4, BATCH_TILE, 4):
        _mm_storeu_ps(y + ${N}, vy${ABC[N:N+4]});
      y += ${BATCH_TILE};
    }
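  // Process the remaining full groups of 4 elements with the same computation.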
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    __m128 vx = _mm_loadu_ps(x);
    x += 4;

    const __m128 vz = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx, vprescale));

    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);
    __m128 vs = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn), 23));
    vn = _mm_sub_ps(vn, vmagic_bias);

    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    __m128 vp = _mm_add_ps(_mm_mul_ps(vc6, vt), vc5);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc4);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc3);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2);
    vp = _mm_mul_ps(vp, vt);

    vt = _mm_mul_ps(vt, vs);
    vs = _mm_sub_ps(vs, vone);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vt);
    const __m128 ve = _mm_mul_ps(_mm_add_ps(vp, vs), valpha);

    $if SSE < 4:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
    vx = _mm_mul_ps(vx, vbeta);
    $if SSE >= 4:
      const __m128 vy = _mm_blendv_ps(vx, ve, vx);
    $else:
      const __m128 vy = _mm_or_ps(_mm_and_ps(ve, vm), _mm_andnot_ps(vm, vx));

    _mm_storeu_ps(y, vy);
    y += 4;
  }
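  // Remainder of 1-3 elements: compute a full vector and store only the valid lanes.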
  if XNN_UNLIKELY(n != 0) {
    __m128 vx = _mm_loadu_ps(x);

    const __m128 vz = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx, vprescale));

    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);
    __m128 vs = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn), 23));
    vn = _mm_sub_ps(vn, vmagic_bias);

    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    __m128 vp = _mm_add_ps(_mm_mul_ps(vc6, vt), vc5);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc4);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc3);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2);
    vp = _mm_mul_ps(vp, vt);

    vt = _mm_mul_ps(vt, vs);
    vs = _mm_sub_ps(vs, vone);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vt);
    const __m128 ve = _mm_mul_ps(_mm_add_ps(vp, vs), valpha);

    $if SSE < 4:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
    vx = _mm_mul_ps(vx, vbeta);
    $if SSE >= 4:
      __m128 vy = _mm_blendv_ps(vx, ve, vx);
    $else:
      __m128 vy = _mm_or_ps(_mm_and_ps(ve, vm), _mm_andnot_ps(vm, vx));

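    // Store the low two lanes and/or the final lane depending on how many elements remain.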
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vy);
      vy = _mm_movehl_ps(vy, vy);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vy);
    }
  }
}