// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/vunary.h>
#include <xnnpack/common.h>


$PARAMS_STRUCT = "neonfma_rr1_p6" if FMA else "neon_rr2_p6"
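// ELU applied lane-wise to a vector of floats. A scalar sketch of what each
// lane computes (prescale, alpha, and beta come from the params struct):
//   if (x < 0): y = alpha * (exp(x * prescale) - 1)
//   else:       y = beta * x
// exp() is approximated as 2**n * (1 + t + t * p(t)) with a degree-6 polynomial.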
void xnn_f32_velu_ukernel__${"neonfma" if FMA else "neon"}_rr${1 if FMA else 2}_p6_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const float32x4_t vprescale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.prescale);
  const float32x4_t valpha = vld1q_dup_f32(&params->${PARAMS_STRUCT}.alpha);
  const float32x4_t vbeta = vld1q_dup_f32(&params->${PARAMS_STRUCT}.beta);
  const float32x4_t vsat_cutoff = vld1q_dup_f32(&params->${PARAMS_STRUCT}.sat_cutoff);
  const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
  const float32x4_t vlog2e = vld1q_dup_f32(&params->${PARAMS_STRUCT}.log2e);
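  // With FMA, a single rounded -ln(2) constant is accurate enough for the fused
  // multiply-add used in argument reduction; without FMA the constant is split
  // into high and low parts to compensate for intermediate rounding.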
  $if FMA:
    const float32x4_t vminus_ln2 = vld1q_dup_f32(&params->neonfma_rr1_p6.minus_ln2);
  $else:
    const float32x4_t vminus_ln2_hi = vld1q_dup_f32(&params->neon_rr2_p6.minus_ln2_hi);
    const float32x4_t vminus_ln2_lo = vld1q_dup_f32(&params->neon_rr2_p6.minus_ln2_lo);
  const float32x4_t vc6 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c6);
  const float32x4_t vc5 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c5);
  const float32x4_t vc4 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c4);
  const float32x4_t vc3 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c3);
  const float32x4_t vc2 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c2);
  const float32x4_t vone = vmovq_n_f32(1.0f);

  $if BATCH_TILE > 4:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vx${ABC[N:N+4]} = vld1q_f32(x); x += 4;

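      // z := x * prescale, clamped from below at sat_cutoff so that the
      // exponential approximation stays in its valid range (the ELU output is
      // already saturated near -alpha for such inputs).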
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vz${ABC[N:N+4]} = vmaxq_f32(vmulq_f32(vx${ABC[N:N+4]}, vprescale), vsat_cutoff);

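      // n := round(z / ln(2)), computed with the "magic bias" trick: adding a
      // large constant places the integer part in the low mantissa bits.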
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vn${ABC[N:N+4]} = ${VMULADDQ_F32}(vmagic_bias, vz${ABC[N:N+4]}, vlog2e);

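      // s := 2**n, built by shifting n's low bits into the float exponent
      // field; the magic bias is then subtracted to recover n as a float.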
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vs${ABC[N:N+4]} = vreinterpretq_f32_s32(vshlq_n_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), 23));
        vn${ABC[N:N+4]} = vsubq_f32(vn${ABC[N:N+4]}, vmagic_bias);

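      // t := z - n * ln(2), the reduced argument of the exponential.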
      $if FMA:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vz${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2);
      $else:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vz${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2_hi);

        $for N in range(0, BATCH_TILE, 4):
          vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2_lo);

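      // Evaluate p := t * (c2 + t * (c3 + t * (c4 + t * (c5 + t * c6)))) by
      // Horner's rule, so that exp(t) ~= 1 + t + t * p below.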
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc5, vc6, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc4, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc3, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc2, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = vmulq_f32(vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

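      // Reconstruct exp(z) - 1 = s * (1 + t + t * p) - 1 as s * t * (1 + p) + (s - 1):
      // t is scaled by s, s is decremented, and p is folded in with one more multiply-add.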
      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = vmulq_f32(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});
        vs${ABC[N:N+4]} = vsubq_f32(vs${ABC[N:N+4]}, vone);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

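      // e := alpha * (exp(z) - 1), the ELU result for negative inputs.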
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t ve${ABC[N:N+4]} = vmulq_f32(vaddq_f32(vp${ABC[N:N+4]}, vs${ABC[N:N+4]}), valpha);

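      // For non-negative inputs the result is beta * x; the sign mask selects
      // between the two branches without divergent control flow.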
      $for N in range(0, BATCH_TILE, 4):
        const uint32x4_t vm${ABC[N:N+4]} = vcltq_f32(vx${ABC[N:N+4]}, vmovq_n_f32(0.0f));
        vx${ABC[N:N+4]} = vmulq_f32(vx${ABC[N:N+4]}, vbeta);

      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vy${ABC[N:N+4]} = vbslq_f32(vm${ABC[N:N+4]}, ve${ABC[N:N+4]}, vx${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vst1q_f32(y, vy${ABC[N:N+4]}); y += 4;
    }
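  // Same computation, four elements at a time, for whatever is left of the batch.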
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    float32x4_t vx = vld1q_f32(x); x += 4;

    const float32x4_t vz = vmaxq_f32(vmulq_f32(vx, vprescale), vsat_cutoff);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vz, vlog2e);
    float32x4_t vs = vreinterpretq_f32_s32(vshlq_n_s32(vreinterpretq_s32_f32(vn), 23));
    vn = vsubq_f32(vn, vmagic_bias);

    $if FMA:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vminus_ln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vminus_ln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vminus_ln2_lo);

    float32x4_t vp = ${VMULADDQ_F32}(vc5, vc6, vt);
    vp = ${VMULADDQ_F32}(vc4, vp, vt);
    vp = ${VMULADDQ_F32}(vc3, vp, vt);
    vp = ${VMULADDQ_F32}(vc2, vp, vt);
    vp = vmulq_f32(vp, vt);

    vt = vmulq_f32(vt, vs);
    vs = vsubq_f32(vs, vone);
    vp = ${VMULADDQ_F32}(vt, vp, vt);
    const float32x4_t ve = vmulq_f32(vaddq_f32(vp, vs), valpha);

    const uint32x4_t vm = vcltq_f32(vx, vmovq_n_f32(0.0f));
    vx = vmulq_f32(vx, vbeta);
    const float32x4_t vy = vbslq_f32(vm, ve, vx);

    vst1q_f32(y, vy); y += 4;
  }
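  // Tail of 1 to 3 elements: a full vector is loaded (the kernel is declared
  // XNN_OOB_READS, so reading past the end of x is permitted) and only the
  // valid lanes of the result are stored.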
  if XNN_UNLIKELY(n != 0) {
    float32x4_t vx = vld1q_f32(x);

    const float32x4_t vz = vmaxq_f32(vmulq_f32(vx, vprescale), vsat_cutoff);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vz, vlog2e);
    float32x4_t vs = vreinterpretq_f32_s32(vshlq_n_s32(vreinterpretq_s32_f32(vn), 23));
    vn = vsubq_f32(vn, vmagic_bias);

    $if FMA:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vminus_ln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vminus_ln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vminus_ln2_lo);

    float32x4_t vp = ${VMULADDQ_F32}(vc5, vc6, vt);
    vp = ${VMULADDQ_F32}(vc4, vp, vt);
    vp = ${VMULADDQ_F32}(vc3, vp, vt);
    vp = ${VMULADDQ_F32}(vc2, vp, vt);
    vp = vmulq_f32(vp, vt);

    vt = vmulq_f32(vt, vs);
    vs = vsubq_f32(vs, vone);
    vp = ${VMULADDQ_F32}(vt, vp, vt);
    const float32x4_t ve = vmulq_f32(vaddq_f32(vp, vs), valpha);

    const uint32x4_t vm = vcltq_f32(vx, vmovq_n_f32(0.0f));
    vx = vmulq_f32(vx, vbeta);
    const float32x4_t vy = vbslq_f32(vm, ve, vx);

    float32x2_t vy_lo = vget_low_f32(vy);
    if (n & (2 * sizeof(float))) {
      vst1_f32(y, vy_lo); y += 2;
      vy_lo = vget_high_f32(vy);
    }
    if (n & (1 * sizeof(float))) {
      vst1_lane_f32(y, vy_lo, 0);
    }
  }
}