• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert BATCH_TILE >= 1
7#include <assert.h>
8#include <math.h>
9
10#include <xnnpack/common.h>
11#include <xnnpack/math.h>
12#include <xnnpack/vunary.h>
13
14
15extern XNN_INTERNAL const uint32_t xnn_table_exp2minus_k_over_16[16];
16
17void xnn_f32_velu_ukernel__${"wasm" if WASM else "scalar"}_rr2_lut16_p3_x${BATCH_TILE}(
18    size_t n,
19    const float* x,
20    float* y,
21    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
22{
23  assert(n % sizeof(float) == 0);
24
25  const float vprescale = params->scalar_rr2_lut16_p3.prescale;
26  const float valpha = params->scalar_rr2_lut16_p3.alpha;
27  const float vbeta = params->scalar_rr2_lut16_p3.beta;
28  const float vmagic_bias = params->scalar_rr2_lut16_p3.magic_bias;
29  const float vlog2e = params->scalar_rr2_lut16_p3.log2e;
30  const uint32_t vindex_mask = UINT32_C(0xF);
31  const float vsat_cutoff = params->scalar_rr2_lut16_p3.sat_cutoff;
32  const float vminus_ln2_hi = params->scalar_rr2_lut16_p3.minus_ln2_hi;
33  const float vminus_ln2_lo = params->scalar_rr2_lut16_p3.minus_ln2_lo;
34  const float vc3 = params->scalar_rr2_lut16_p3.c3;
35  const float vc2 = params->scalar_rr2_lut16_p3.c2;
36  const float vone = params->scalar_rr2_lut16_p3.one;
37
38  $if BATCH_TILE > 1:
39    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
40      $for N in range(BATCH_TILE):
41        float vx${N} = x[${N}];
42      x += ${BATCH_TILE};
43
44      $for N in range(BATCH_TILE):
45        $if WASM:
46          const float vz${N} = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx${N} * vprescale, vsat_cutoff), 0.0f);
47        $else:
48          const float vz${N} = vx${N} * vprescale;
49
50      $for N in range(BATCH_TILE):
51        float vn${N} = vz${N} * vlog2e + vmagic_bias;
52
53      $for N in range(BATCH_TILE):
54        const uint32_t ven${N} = float_as_uint32(vn${N}) << 19;
55        const uint32_t vidx${N} = float_as_uint32(vn${N}) & vindex_mask;
56        vn${N} -= vmagic_bias;
57
58      $for N in range(BATCH_TILE):
59        float vt${N} = vn${N} * vminus_ln2_hi + vz${N};
60        float vs${N} = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx${N}] + ven${N});
61
62      $for N in range(BATCH_TILE):
63        vt${N} = vn${N} * vminus_ln2_lo + vt${N};
64        $if not WASM:
65          if XNN_UNPREDICTABLE(vz${N} <= vsat_cutoff) {
66            vs${N} = 0.0f;
67            vt${N} = 0.0f;
68          }
69
70      $for N in range(BATCH_TILE):
71        float vp${N} = vc3 * vt${N} + vc2;
72
73      $for N in range(BATCH_TILE):
74        vp${N} *= vt${N};
75
76      $for N in range(BATCH_TILE):
77        vt${N} *= vs${N};
78        vs${N} -= vone;
79
80      $for N in range(BATCH_TILE):
81        vp${N} = vp${N} * vt${N} + vt${N};
82
83      $for N in range(BATCH_TILE):
84        const float ve${N} = (vp${N} + vs${N}) * valpha;
85        $if WASM:
86          float vy${N} = __builtin_wasm_max_f32(vx${N} * vbeta, 0.0f);
87        $else:
88          float vy${N} = vx${N} * vbeta;
89
90      $if WASM:
91        $for N in range(BATCH_TILE):
92          vy${N} += __builtin_wasm_min_f32(ve${N}, 0.0f);
93      $else:
94        $for N in range(BATCH_TILE):
95          if XNN_UNPREDICTABLE(vx${N} < 0.0f) {
96            vy${N} = ve${N};
97          }
98
99      $for N in range(BATCH_TILE):
100        y[${N}] = vy${N};
101      y += ${BATCH_TILE};
102    }
103  $if BATCH_TILE == 1:
104    do {
105      float vx = *x++;
106
107      $if WASM:
108        const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
109      $else:
110        const float vz = vx * vprescale;
111
112      float vn = vz * vlog2e + vmagic_bias;
113      const uint32_t ven = float_as_uint32(vn) << 19;
114      const uint32_t vidx = float_as_uint32(vn) & vindex_mask;
115      vn -= vmagic_bias;
116
117      float vt = vn * vminus_ln2_hi + vz;
118      float vs = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx] + ven);
119
120      vt = vn * vminus_ln2_lo + vt;
121      $if not WASM:
122        if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
123          vs = 0.0f;
124          vt = 0.0f;
125        }
126
127      float vp = vc3 * vt + vc2;
128      vp *= vt;
129
130      vt *= vs;
131      vs -= vone;
132      vp = vp * vt + vt;
133      const float ve = (vp + vs) * valpha;
134
135      $if WASM:
136        float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
137        vy += __builtin_wasm_min_f32(ve, 0.0f);
138      $else:
139        float vy = vx * vbeta;
140        if XNN_UNPREDICTABLE(vx < 0.0f) {
141          vy = ve;
142        }
143
144      *y++ = vy;
145
146      n -= sizeof(float);
147    } while (n != 0);
148  $elif BATCH_TILE == 2:
149    if XNN_UNLIKELY(n != 0) {
150      float vx = *x;
151
152      $if WASM:
153        const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
154      $else:
155        const float vz = vx * vprescale;
156
157      float vn = vz * vlog2e + vmagic_bias;
158      const uint32_t ven = float_as_uint32(vn) << 19;
159      const uint32_t vidx = float_as_uint32(vn) & vindex_mask;
160      vn -= vmagic_bias;
161
162      float vt = vn * vminus_ln2_hi + vz;
163      float vs = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx] + ven);
164
165      vt = vn * vminus_ln2_lo + vt;
166      $if not WASM:
167        if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
168          vs = 0.0f;
169          vt = 0.0f;
170        }
171
172      float vp = vc3 * vt + vc2;
173      vp *= vt;
174
175      vt *= vs;
176      vs -= vone;
177      vp = vp * vt + vt;
178      const float ve = (vp + vs) * valpha;
179
180      $if WASM:
181        float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
182        vy += __builtin_wasm_min_f32(ve, 0.0f);
183      $else:
184        float vy = vx * vbeta;
185        if XNN_UNPREDICTABLE(vx < 0.0f) {
186          vy = ve;
187        }
188
189      *y = vy;
190    }
191  $else:
192    if XNN_UNLIKELY(n != 0) {
193      do {
194        float vx = *x++;
195
196        $if WASM:
197          const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
198        $else:
199          const float vz = vx * vprescale;
200
201        float vn = vz * vlog2e + vmagic_bias;
202        const uint32_t ven = float_as_uint32(vn) << 19;
203        const uint32_t vidx = float_as_uint32(vn) & vindex_mask;
204        vn -= vmagic_bias;
205
206        float vt = vn * vminus_ln2_hi + vz;
207        float vs = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx] + ven);
208
209        vt = vn * vminus_ln2_lo + vt;
210        $if not WASM:
211          if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
212            vs = 0.0f;
213            vt = 0.0f;
214          }
215
216        float vp = vc3 * vt + vc2;
217        vp *= vt;
218
219        vt *= vs;
220        vs -= vone;
221        vp = vp * vt + vt;
222        const float ve = (vp + vs) * valpha;
223
224        $if WASM:
225          float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
226          vy += __builtin_wasm_min_f32(ve, 0.0f);
227        $else:
228          float vy = vx * vbeta;
229          if XNN_UNPREDICTABLE(vx < 0.0f) {
230            vy = ve;
231          }
232
233        *y++ = vy;
234
235        n -= sizeof(float);
236      } while (n != 0);
237    }
238}
239