// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$assert RR_STEPS in [1, 2]
$assert DIV_ALGO in ["div", "nr2fma", "nr2recps", "nr1recps1fma"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


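// Approximates the sigmoid function
//
//   f(x) := 1 / (1 + exp(-x))
//
// on a batch of floats. The kernel evaluates q := exp(-|x|) / (1 + exp(-|x|)) = f(-|x|), then selects q for
// negative inputs and 1 - q otherwise. exp(-|x|) itself is computed as 2**n * exp(-t), where
// n := round(-|x| / log(2)) and t := |x| + n * log(2) is a reduced argument in [-log(2)/2, log(2)/2], and
// exp(-t) is approximated by a degree-5 polynomial (the "p5" in the kernel name). RR_STEPS selects whether the
// multiplication by log(2) in the reduction uses a single constant or a high/low split, and DIV_ALGO selects
// how the final division is performed (true division, or a VRECPE estimate refined by Newton-Raphson steps).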
void xnn_f32_sigmoid_ukernel__${"neonfma" if FMA else "neon"}_rr${RR_STEPS}_p5_${DIV_ALGO}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const void* params) XNN_DISABLE_TSAN
{
  assert(n % sizeof(float) == 0);

  const float32x4_t vmagic_bias = vmovq_n_f32(0x1.8000FEp23f);
  const float32x4_t vminus_log2e = vmovq_n_f32(-0x1.715476p+0f);
  $if RR_STEPS == 1:
    const float32x4_t vln2 = vmovq_n_f32(0x1.62E43p-1f);
  $else:
    $if FMA:
      const float32x4_t vln2_hi = vmovq_n_f32(0x1.62E43p-1f);
      const float32x4_t vln2_lo = vmovq_n_f32(-0x1.05C61p-29f);
    $else:
      const float32x4_t vln2_hi = vmovq_n_f32(0x1.62E400p-1f);
      const float32x4_t vln2_lo = vmovq_n_f32(0x1.7F7D1Cp-20f);
  const float32x4_t vc5 = vmovq_n_f32(-0x1.0F9F9Cp-7f);
  const float32x4_t vc4 = vmovq_n_f32(0x1.573A1Ap-5f);
  const float32x4_t vc3 = vmovq_n_f32(-0x1.555A80p-3f);
  const float32x4_t vc2 = vmovq_n_f32(0x1.FFFDC6p-2f);
  const float32x4_t vc1 = vmovq_n_f32(-0x1.FFFFF6p-1f);
  const float32x4_t vone = vmovq_n_f32(1.0f);
  const float32x4_t vdenorm_cutoff = vmovq_n_f32(-0x1.5D589Ep+6f);

  $if BATCH_TILE > 4:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vx${ABC[N:N+4]} = vld1q_f32(x); x += 4;

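      // vz := |x|; the code below computes vf := exp(-vz) / (1 + exp(-vz)) = sigmoid(-vz), then corrects the sign at the end.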
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vz${ABC[N:N+4]} = vabsq_f32(vx${ABC[N:N+4]});

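      // vn := round(-vz / log(2)): adding the "magic bias" rounds the product to an integer and leaves that
      // integer (plus the exponent bias) in the low mantissa bits of vn.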
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vn${ABC[N:N+4]} = ${VMULADDQ_F32}(vmagic_bias, vz${ABC[N:N+4]}, vminus_log2e);

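      // vs := 2**n, built by shifting the biased integer out of the mantissa into the exponent field.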
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vs${ABC[N:N+4]} = vreinterpretq_f32_s32(vshlq_n_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), 23));

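      // Subtract the magic bias to recover n as a float in vn.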
      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = vsubq_f32(vn${ABC[N:N+4]}, vmagic_bias);

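      // vt := vz + vn * log(2), the reduced argument in [-log(2)/2, log(2)/2]. With RR_STEPS == 2, log(2) is
      // split into high and low parts (Cody-Waite style) for a more accurate reduction.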
      $if RR_STEPS == 1:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vz${ABC[N:N+4]}, vn${ABC[N:N+4]}, vln2);
      $else:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vz${ABC[N:N+4]}, vn${ABC[N:N+4]}, vln2_hi);

        $for N in range(0, BATCH_TILE, 4):
          vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vn${ABC[N:N+4]}, vln2_lo);

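      // Evaluate the polynomial part of the degree-5 approximation
      // exp(-vt) ~= 1 + vt * (c1 + vt * (c2 + vt * (c3 + vt * (c4 + vt * c5)))); only the inner factor is computed here.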
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc4, vc5, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc3, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc2, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vc1, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

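      // ve := vs * (1 + vt * p(vt)) ~= exp(-vz), reconstructed as vs + (vt * vs) * p(vt).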
      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = vmulq_f32(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t ve${ABC[N:N+4]} = ${VMULADDQ_F32}(vs${ABC[N:N+4]}, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

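      // vd := 1 + exp(-vz), the denominator of sigmoid(-vz).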
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vd${ABC[N:N+4]} = vaddq_f32(ve${ABC[N:N+4]}, vone);

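      // vf := ve / vd. The "div" variant divides directly; the other variants approximate 1/vd with a VRECPE
      // estimate refined by two Newton-Raphson iterations (VRECPS and/or FMA steps, per DIV_ALGO) and multiply by ve.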
      $if DIV_ALGO == "div":
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vf${ABC[N:N+4]} = vdivq_f32(ve${ABC[N:N+4]}, vd${ABC[N:N+4]});
      $else:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vr${ABC[N:N+4]} = vrecpeq_f32(vd${ABC[N:N+4]});

        $if DIV_ALGO == "nr2fma":
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vfmaq_f32(vr${ABC[N:N+4]}, vr${ABC[N:N+4]}, vfmsq_f32(vone, vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));
        $else:
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vmulq_f32(vr${ABC[N:N+4]}, vrecpsq_f32(vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));

        $if DIV_ALGO == "nr2recps":
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vmulq_f32(vr${ABC[N:N+4]}, vrecpsq_f32(vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));
        $else:
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vfmaq_f32(vr${ABC[N:N+4]}, vr${ABC[N:N+4]}, vfmsq_f32(vone, vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));

        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vf${ABC[N:N+4]} = vmulq_f32(ve${ABC[N:N+4]}, vr${ABC[N:N+4]});

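      // Inputs with |x| above the cutoff (~87.34) would make exp(-|x|) underflow, so flush vf to 0 there;
      // NaN inputs fail the comparison and are left unchanged.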
      $for N in range(0, BATCH_TILE, 4):
        vf${ABC[N:N+4]} = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf${ABC[N:N+4]}), vcagtq_f32(vx${ABC[N:N+4]}, vdenorm_cutoff)));

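      // vf approximates sigmoid(-|x|): keep it for negative inputs and use 1 - vf for non-negative inputs.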
      $for N in range(0, BATCH_TILE, 4):
        const uint32x4_t vm${ABC[N:N+4]} = vcltq_f32(vx${ABC[N:N+4]}, vmovq_n_f32(0.0f));

      $for N in range(0, BATCH_TILE, 4):
        vf${ABC[N:N+4]} = vbslq_f32(vm${ABC[N:N+4]}, vf${ABC[N:N+4]}, vsubq_f32(vone, vf${ABC[N:N+4]}));

      $for N in range(0, BATCH_TILE, 4):
        vst1q_f32(y, vf${ABC[N:N+4]}); y += 4;
    }
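  // Process remaining elements four at a time (one float32x4_t per iteration).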
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    const float32x4_t vx = vld1q_f32(x); x += 4;

    const float32x4_t vz = vabsq_f32(vx);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vz, vminus_log2e);
    const float32x4_t vs = vreinterpretq_f32_s32(vshlq_n_s32(vreinterpretq_s32_f32(vn), 23));
    vn = vsubq_f32(vn, vmagic_bias);
    $if RR_STEPS == 1:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vln2_lo);

    float32x4_t vp = ${VMULADDQ_F32}(vc4, vc5, vt);
    vp = ${VMULADDQ_F32}(vc3, vp, vt);
    vp = ${VMULADDQ_F32}(vc2, vp, vt);
    vp = ${VMULADDQ_F32}(vc1, vp, vt);

    vt = vmulq_f32(vt, vs);
    const float32x4_t ve = ${VMULADDQ_F32}(vs, vp, vt);
    const float32x4_t vd = vaddq_f32(ve, vone);

    $if DIV_ALGO == "div":
      float32x4_t vf = vdivq_f32(ve, vd);
    $else:
      float32x4_t vr = vrecpeq_f32(vd);
      $if DIV_ALGO == "nr2fma":
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));
      $else:
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $if DIV_ALGO == "nr2recps":
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $else:
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));

      float32x4_t vf = vmulq_f32(ve, vr);
    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcagtq_f32(vx, vdenorm_cutoff)));
    const uint32x4_t vm = vcltq_f32(vx, vmovq_n_f32(0.0f));
    vf = vbslq_f32(vm, vf, vsubq_f32(vone, vf));

    vst1q_f32(y, vf); y += 4;
  }
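  // Handle the final 1-3 elements: compute a full vector of 4 and store only the valid lanes.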
  if XNN_UNLIKELY(n != 0) {
    const float32x4_t vx = vld1q_f32(x);

    const float32x4_t vz = vabsq_f32(vx);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vz, vminus_log2e);
    const float32x4_t vs = vreinterpretq_f32_s32(vshlq_n_s32(vreinterpretq_s32_f32(vn), 23));
    vn = vsubq_f32(vn, vmagic_bias);
    $if RR_STEPS == 1:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vln2_lo);

    float32x4_t vp = ${VMULADDQ_F32}(vc4, vc5, vt);
    vp = ${VMULADDQ_F32}(vc3, vp, vt);
    vp = ${VMULADDQ_F32}(vc2, vp, vt);
    vp = ${VMULADDQ_F32}(vc1, vp, vt);

    vt = vmulq_f32(vt, vs);
    const float32x4_t ve = ${VMULADDQ_F32}(vs, vp, vt);
    const float32x4_t vd = vaddq_f32(ve, vone);

    $if DIV_ALGO == "div":
      float32x4_t vf = vdivq_f32(ve, vd);
    $else:
      float32x4_t vr = vrecpeq_f32(vd);
      $if DIV_ALGO == "nr2fma":
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));
      $else:
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $if DIV_ALGO == "nr2recps":
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $else:
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));

      float32x4_t vf = vmulq_f32(ve, vr);
    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcagtq_f32(vx, vdenorm_cutoff)));
    const uint32x4_t vm = vcltq_f32(vx, vmovq_n_f32(0.0f));
    vf = vbslq_f32(vm, vf, vsubq_f32(vone, vf));

    float32x2_t vf_lo = vget_low_f32(vf);
    if (n & (2 * sizeof(float))) {
      vst1_f32(y, vf_lo); y += 2;
      vf_lo = vget_high_f32(vf);
    }
    if (n & (1 * sizeof(float))) {
      vst1_lane_f32(y, vf_lo, 0);
    }
  }
}