• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert BATCH_TILE >= 1
7#include <assert.h>
8#include <math.h>
9
10#include <xnnpack/common.h>
11#include <xnnpack/vunary.h>
12
13#include <fp16/bitcasts.h>
14
15
// ELU micro-kernel template. The WASM flag selects the WebAssembly variant
// (min/max builtins) or the plain scalar variant (branches); BATCH_TILE sets
// how many elements each main-loop iteration handles.
//
// Parameters:
//   n      - amount of data to process, in bytes; must be a multiple of
//            sizeof(float) (asserted below).
//   x      - input floats, read sequentially.
//   y      - output floats, written sequentially.
//   params - supplies the prescale, alpha and beta scalars.
//
// Per element, with vz = vx * vprescale, the code forms
//   vn = vz * vlog2e rounded to an integer (magic-bias addition),
//   vs = scale built from the low bits of the biased vn,
//   vt = vz - vn * ln2 (ln2 applied as a hi part and a lo part),
//   ve = valpha * (vs * (1 + vt + vc2*vt^2 + ... + vc6*vt^6) - 1),
// i.e. a polynomial approximation of valpha * (exp(vz) - 1).
// Scalar variant: vy = vx * vbeta, replaced by ve when vx < 0.
// WASM variant:   vz is clamped to [vsat_cutoff, 0] and
//                 vy = max(vx * vbeta, 0) + min(ve, 0).
16void xnn_f32_velu_ukernel__${"wasm" if WASM else "scalar"}_rr2_p6_x${BATCH_TILE}(
17    size_t n,
18    const float* x,
19    float* y,
20    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
21{
  // n counts bytes, so it has to cover a whole number of floats.
22  assert(n % sizeof(float) == 0);
23
24  const float vprescale = params->scalar.prescale;
25  const float valpha = params->scalar.alpha;
26  const float vbeta = params->scalar.beta;
27
  // Constants of the approximation: the magic bias used for rounding, log2(e),
  // the input cutoff at or below which the result is forced to -valpha, the
  // two pieces of -ln2, and the polynomial coefficients vc2..vc6.
  // NOTE(review): the low bits of vmagic_bias look like they pre-add the float
  // exponent bias so that the << 23 below yields a power of two directly - confirm.
28  const float vmagic_bias = 0x1.8000FEp23f;
29  const float vlog2e = 0x1.715476p+0f;
30  const float vsat_cutoff = -0x1.154246p+4f;
31  const float vminus_ln2_hi = -0x1.62E440p-1f;
32  const float vminus_ln2_lo = 0x1.0105C6p-21f;
33  const float vc6 = 0x1.6b7338p-10f;
34  const float vc5 = 0x1.12278Ep-7f;
35  const float vc4 = 0x1.555716p-5f;
36  const float vc3 = 0x1.5554B0p-3f;
37  const float vc2 = 0x1.FFFFFEp-2f;
38  const float vone = 1.0f;
39
40  $if BATCH_TILE > 1:
    // Main loop: BATCH_TILE elements per iteration, each step written out for
    // all elements before moving on to the next step.
41    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
42      $for N in range(BATCH_TILE):
43        float vx${N} = x[${N}];
44      x += ${BATCH_TILE};
45
      // Prescaled input; the WASM variant also clamps it to [vsat_cutoff, 0].
46      $for N in range(BATCH_TILE):
47        $if WASM:
48          const float vz${N} = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx${N} * vprescale, vsat_cutoff), 0.0f);
49        $else:
50          const float vz${N} = vx${N} * vprescale;
51
      // Adding the magic bias leaves vz * log2e, rounded to an integer, in the
      // low bits of vn.
52      $for N in range(BATCH_TILE):
53        float vn${N} = vz${N} * vlog2e + vmagic_bias;
54
      // Shift those low bits into the exponent field to form the scale vs,
      // then remove the bias so vn holds the rounded value itself.
55      $for N in range(BATCH_TILE):
56        float vs${N} = fp32_from_bits(fp32_to_bits(vn${N}) << 23);
57        vn${N} -= vmagic_bias;
58
      // Reduced argument vt = vz - vn * ln2, subtracting ln2 in two pieces.
59      $for N in range(BATCH_TILE):
60        float vt${N} = vn${N} * vminus_ln2_hi + vz${N};
61
62      $for N in range(BATCH_TILE):
63        vt${N} = vn${N} * vminus_ln2_lo + vt${N};
64
      // Scalar variant only: at or below the cutoff, zeroing vs and vt makes
      // the remaining arithmetic produce ve = -valpha.
65      $if not WASM:
66        $for N in range(BATCH_TILE):
67          if XNN_UNPREDICTABLE(vz${N} <= vsat_cutoff) {
68            vs${N} = 0.0f;
69            vt${N} = 0.0f;
70          }
71
      // Horner evaluation: vp = vc2*vt + vc3*vt^2 + ... + vc6*vt^5.
72      $for N in range(BATCH_TILE):
73        float vp${N} = vc6 * vt${N} + vc5;
74
75      $for N in range(BATCH_TILE):
76        vp${N} = vp${N} * vt${N} + vc4;
77
78      $for N in range(BATCH_TILE):
79        vp${N} = vp${N} * vt${N} + vc3;
80
81      $for N in range(BATCH_TILE):
82        vp${N} = vp${N} * vt${N} + vc2;
83
84      $for N in range(BATCH_TILE):
85        vp${N} *= vt${N};
86
      // Fold in the scale: vt becomes vs*vt and vs becomes vs - 1, so that
      // vp + vs below equals vs * (1 + vt + vc2*vt^2 + ...) - 1.
87      $for N in range(BATCH_TILE):
88        vt${N} *= vs${N};
89        vs${N} -= vone;
90
91      $for N in range(BATCH_TILE):
92        vp${N} = vp${N} * vt${N} + vt${N};
93
      // ve is the result for negative inputs; vy starts from vx * vbeta.
94      $for N in range(BATCH_TILE):
95        const float ve${N} = (vp${N} + vs${N}) * valpha;
96        $if WASM:
97          float vy${N} = __builtin_wasm_max_f32(vx${N} * vbeta, 0.0f);
98        $else:
99          float vy${N} = vx${N} * vbeta;
100
      // Combine: WASM adds min(ve, 0) to the clamped linear term; the scalar
      // variant overwrites vy with ve when vx is negative.
101      $if WASM:
102        $for N in range(BATCH_TILE):
103          vy${N} += __builtin_wasm_min_f32(ve${N}, 0.0f);
104      $else:
105        $for N in range(BATCH_TILE):
106          if XNN_UNPREDICTABLE(vx${N} < 0.0f) {
107            vy${N} = ve${N};
108          }
109
110      $for N in range(BATCH_TILE):
111        y[${N}] = vy${N};
112      y += ${BATCH_TILE};
113    }
114  $if BATCH_TILE == 1:
    // Single-element variant: the whole input is handled by this loop.
    // NOTE(review): the body runs before n is tested, so n == 0 would wrap the
    // counter; presumably callers never pass an empty batch - confirm.
115    do {
116      float vx = *x++;
117
118      $if WASM:
119        const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
120      $else:
121        const float vz = vx * vprescale;
122
      // Rounded vz * log2e via the magic bias, then the scale vs from its bits.
123      float vn = vz * vlog2e + vmagic_bias;
124      float vs = fp32_from_bits(fp32_to_bits(vn) << 23);
125      vn -= vmagic_bias;
126
      // Reduced argument vt = vz - vn * ln2 (hi piece, then lo piece).
127      float vt = vn * vminus_ln2_hi + vz;
128      vt = vn * vminus_ln2_lo + vt;
129
      // Scalar variant only: force ve to -valpha at or below the cutoff.
130      $if not WASM:
131        if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
132          vs = 0.0f;
133          vt = 0.0f;
134        }
135
      // Horner evaluation of the polynomial in vt.
136      float vp = vc6 * vt + vc5;
137      vp = vp * vt + vc4;
138      vp = vp * vt + vc3;
139      vp = vp * vt + vc2;
140      vp *= vt;
141
      // ve = valpha * (vs * (1 + vt + vc2*vt^2 + ...) - 1).
142      vt *= vs;
143      vs -= vone;
144      vp = vp * vt + vt;
145      const float ve = (vp + vs) * valpha;
146
147      $if WASM:
148        float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
149        vy += __builtin_wasm_min_f32(ve, 0.0f);
150      $else:
151        float vy = vx * vbeta;
152        if XNN_UNPREDICTABLE(vx < 0.0f) {
153          vy = ve;
154        }
155
156      *y++ = vy;
157
158      n -= sizeof(float);
159    } while (n != 0);
160  $elif BATCH_TILE == 2:
    // With two elements per main-loop iteration, at most one element can be
    // left over; it goes through the same steps as one lane of the main loop.
161    if XNN_UNLIKELY(n != 0) {
162      float vx = *x;
163
164      $if WASM:
165        const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
166      $else:
167        const float vz = vx * vprescale;
168
169      float vn = vz * vlog2e + vmagic_bias;
170      float vs = fp32_from_bits(fp32_to_bits(vn) << 23);
171      vn -= vmagic_bias;
172
173      float vt = vn * vminus_ln2_hi + vz;
174      vt = vn * vminus_ln2_lo + vt;
175
176      $if not WASM:
177        if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
178          vs = 0.0f;
179          vt = 0.0f;
180        }
181
182      float vp = vc6 * vt + vc5;
183      vp = vp * vt + vc4;
184      vp = vp * vt + vc3;
185      vp = vp * vt + vc2;
186      vp *= vt;
187
188      vt *= vs;
189      vs -= vone;
190      vp = vp * vt + vt;
191      const float ve = (vp + vs) * valpha;
192
193      $if WASM:
194        float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
195        vy += __builtin_wasm_min_f32(ve, 0.0f);
196      $else:
197        float vy = vx * vbeta;
198        if XNN_UNPREDICTABLE(vx < 0.0f) {
199          vy = ve;
200        }
201
202      *y = vy;
203    }
204  $else:
    // Remainder: the elements left after the main loop (fewer than BATCH_TILE)
    // are processed one at a time with the same steps as one main-loop lane.
205    if XNN_UNLIKELY(n != 0) {
206      do {
207        float vx = *x++;
208
209        $if WASM:
210          const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
211        $else:
212          const float vz = vx * vprescale;
213
214        float vn = vz * vlog2e + vmagic_bias;
215        float vs = fp32_from_bits(fp32_to_bits(vn) << 23);
216        vn -= vmagic_bias;
217
218        float vt = vn * vminus_ln2_hi + vz;
219        vt = vn * vminus_ln2_lo + vt;
220
221        $if not WASM:
222          if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
223            vs = 0.0f;
224            vt = 0.0f;
225          }
226
227        float vp = vc6 * vt + vc5;
228        vp = vp * vt + vc4;
229        vp = vp * vt + vc3;
230        vp = vp * vt + vc2;
231        vp *= vt;
232
233        vt *= vs;
234        vs -= vone;
235        vp = vp * vt + vt;
236        const float ve = (vp + vs) * valpha;
237
238        $if WASM:
239          float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
240          vy += __builtin_wasm_min_f32(ve, 0.0f);
241        $else:
242          float vy = vx * vbeta;
243          if XNN_UNPREDICTABLE(vx < 0.0f) {
244            vy = ve;
245          }
246
247        *y++ = vy;
248
249        n -= sizeof(float);
250      } while (n != 0);
251    }
252}
253