// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$assert DIV_ALGO in ["div", "nr2fma", "nr2recps", "nr1recps1fma"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
$VMULSUBQ_F32 = "vfmsq_f32" if FMA else "vmlsq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


extern XNN_INTERNAL const float xnn_table_exp2minus_k_over_64[64];

$PARAMS_STRUCT = "neonfma_rr1_lut64_p2" if FMA else "neon_rr2_lut64_p2"
void xnn_f32_vsigmoid_ukernel__${"neonfma" if FMA else "neon"}_rr${1 if FMA else 2}_lut64_p2_${DIV_ALGO}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n % sizeof(float) == 0);

  const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
  const float32x4_t vminus_log2e = vld1q_dup_f32(&params->${PARAMS_STRUCT}.minus_log2e);
  const int32x4_t vindex_mask = vmovq_n_s32(INT32_C(0x3F));
  $if FMA:
    const float32x4_t vln2 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.ln2);
  $else:
    const float32x4_t vln2_hi = vld1q_dup_f32(&params->${PARAMS_STRUCT}.ln2_hi);
    const float32x4_t vln2_lo = vld1q_dup_f32(&params->${PARAMS_STRUCT}.ln2_lo);
  const float32x4_t vc2 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c2);
  const float32x4_t vone = vmovq_n_f32(1.0f);
  const float32x4_t vdenorm_cutoff = vld1q_dup_f32(&params->${PARAMS_STRUCT}.denorm_cutoff);

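  // First compute f := exp(-z) / (1 + exp(-z)) for z := |x|, which equals sigmoid(x) for x <= 0.
  // For x >= 0 the result is reconstructed as 1 - f in the final blend below.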
  $if BATCH_TILE > 4:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vx${ABC[N:N+4]} = vld1q_f32(x); x += 4;

      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vz${ABC[N:N+4]} = vabsq_f32(vx${ABC[N:N+4]});

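      // Compute n := round(z * minus_log2e) with the magic-bias trick: adding the large magic_bias
      // constant rounds the scaled product to an integer held in the low mantissa bits.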
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vn${ABC[N:N+4]} = ${VMULADDQ_F32}(vmagic_bias, vz${ABC[N:N+4]}, vminus_log2e);

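      // Shift the bits of n above the 6-bit table index into the floating-point exponent field
      // (17 = 23 mantissa bits - 6 index bits).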
      $for N in range(0, BATCH_TILE, 4):
        const int32x4_t ve${ABC[N:N+4]} = vshlq_n_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), 17);

      // Use the low 6 bits of n, as an integer, as an index for the table lookup of l := 2**(n % 64).
      $for N in range(0, BATCH_TILE, 4):
        const uint64x2_t vidx${ABC[N:N+4]} = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), vindex_mask));

      $for N in range(0, BATCH_TILE, 4):
        const uint64_t vidx${ABC[N:N+2]} = vgetq_lane_u64(vidx${ABC[N:N+4]}, 0);
        const uint64_t vidx${ABC[N+2:N+4]} = vgetq_lane_u64(vidx${ABC[N:N+4]}, 1);
        float32x2_t vl${ABC[N:N+2]} = vld1_dup_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) vidx${ABC[N:N+2]}]);
        float32x2_t vl${ABC[N+2:N+4]} = vld1_dup_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) vidx${ABC[N+2:N+4]}]);

      $for N in range(0, BATCH_TILE, 4):
        vl${ABC[N:N+2]} = vld1_lane_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) (vidx${ABC[N:N+2]} >> 32)], vl${ABC[N:N+2]}, 1);
        vl${ABC[N+2:N+4]} = vld1_lane_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) (vidx${ABC[N+2:N+4]} >> 32)], vl${ABC[N+2:N+4]}, 1);
        const float32x4_t vl${ABC[N:N+4]} = vcombine_f32(vl${ABC[N:N+2]}, vl${ABC[N+2:N+4]});

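      // Adjust the exponent of the value l fetched from the table to get the final s value.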
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vs${ABC[N:N+4]} = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl${ABC[N:N+4]}), ve${ABC[N:N+4]}));

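      // Subtract the magic bias to recover n as a float.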
      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = vsubq_f32(vn${ABC[N:N+4]}, vmagic_bias);

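      // Compute the reduced argument t := z + n * ln2 (n is non-positive here).
      // Without FMA, ln2 is split into hi and lo parts for extra accuracy.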
      $if FMA:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vz${ABC[N:N+4]}, vn${ABC[N:N+4]}, vln2);
      $else:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vz${ABC[N:N+4]}, vn${ABC[N:N+4]}, vln2_hi);

        $for N in range(0, BATCH_TILE, 4):
          vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vn${ABC[N:N+4]}, vln2_lo);

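      // Compute the degree-2 polynomial term p := t - c2 * t**2, so that exp(-t) ~= 1 - p on the
      // reduced interval.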
      $for N in range(0, BATCH_TILE, 4):
        float32x4_t vp${ABC[N:N+4]} = vmulq_f32(vt${ABC[N:N+4]}, vc2);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULSUBQ_F32}(vt${ABC[N:N+4]}, vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

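      // Reconstruct the exponential: y := s * (1 - p) ~= s * exp(-t) = exp(-z).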
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vy${ABC[N:N+4]} = ${VMULSUBQ_F32}(vs${ABC[N:N+4]}, vs${ABC[N:N+4]}, vp${ABC[N:N+4]});

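      // Denominator of the sigmoid: d := y + 1 = exp(-z) + 1.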
      $for N in range(0, BATCH_TILE, 4):
        const float32x4_t vd${ABC[N:N+4]} = vaddq_f32(vy${ABC[N:N+4]}, vone);

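      // Compute f := y / d, either with a full division or with a reciprocal estimate (VRECPE)
      // refined by two Newton-Raphson steps (VRECPS and/or FMA, selected by DIV_ALGO).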
      $if DIV_ALGO == "div":
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vf${ABC[N:N+4]} = vdivq_f32(vy${ABC[N:N+4]}, vd${ABC[N:N+4]});
      $else:
        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vr${ABC[N:N+4]} = vrecpeq_f32(vd${ABC[N:N+4]});

        $if DIV_ALGO == "nr2fma":
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vfmaq_f32(vr${ABC[N:N+4]}, vr${ABC[N:N+4]}, vfmsq_f32(vone, vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));
        $else:
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vmulq_f32(vr${ABC[N:N+4]}, vrecpsq_f32(vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));

        $if DIV_ALGO == "nr2recps":
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vmulq_f32(vr${ABC[N:N+4]}, vrecpsq_f32(vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));
        $else:
          $for N in range(0, BATCH_TILE, 4):
            vr${ABC[N:N+4]} = vfmaq_f32(vr${ABC[N:N+4]}, vr${ABC[N:N+4]}, vfmsq_f32(vone, vr${ABC[N:N+4]}, vd${ABC[N:N+4]}));

        $for N in range(0, BATCH_TILE, 4):
          float32x4_t vf${ABC[N:N+4]} = vmulq_f32(vy${ABC[N:N+4]}, vr${ABC[N:N+4]});

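      // For |x| above denorm_cutoff the exponential underflows; flush f to zero there so the
      // result saturates to 0 (x < 0) or 1 (x >= 0) after the final blend.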
      $for N in range(0, BATCH_TILE, 4):
        vf${ABC[N:N+4]} = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf${ABC[N:N+4]}), vcagtq_f32(vx${ABC[N:N+4]}, vdenorm_cutoff)));

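      // Reconstruct sigmoid(x): keep f for x < 0 and use 1 - f for x >= 0.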
      $for N in range(0, BATCH_TILE, 4):
        const uint32x4_t vm${ABC[N:N+4]} = vcltq_f32(vx${ABC[N:N+4]}, vmovq_n_f32(0.0f));

      $for N in range(0, BATCH_TILE, 4):
        vf${ABC[N:N+4]} = vbslq_f32(vm${ABC[N:N+4]}, vf${ABC[N:N+4]}, vsubq_f32(vone, vf${ABC[N:N+4]}));

      $for N in range(0, BATCH_TILE, 4):
        vst1q_f32(y, vf${ABC[N:N+4]}); y += 4;
    }
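  // Process full vectors of 4 elements with the same computation.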
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    const float32x4_t vx = vld1q_f32(x); x += 4;

    const float32x4_t vz = vabsq_f32(vx);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vz, vminus_log2e);
    const int32x4_t ve = vshlq_n_s32(vreinterpretq_s32_f32(vn), 17);

    const uint64x2_t vidx = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn), vindex_mask));
    const uint64_t vidx_lo = vgetq_lane_u64(vidx, 0);
    const uint64_t vidx_hi = vgetq_lane_u64(vidx, 1);
    float32x2_t vl_lo = vld1_dup_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) vidx_lo]);
    float32x2_t vl_hi = vld1_dup_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) vidx_hi]);
    vl_lo = vld1_lane_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) (vidx_lo >> 32)], vl_lo, 1);
    vl_hi = vld1_lane_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) (vidx_hi >> 32)], vl_hi, 1);
    const float32x4_t vl = vcombine_f32(vl_lo, vl_hi);

    const float32x4_t vs = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl), ve));
    vn = vsubq_f32(vn, vmagic_bias);
    $if FMA:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vln2_lo);

    float32x4_t vp = vmulq_f32(vt, vc2);
    vp = ${VMULSUBQ_F32}(vt, vp, vt);

    const float32x4_t vy = ${VMULSUBQ_F32}(vs, vs, vp);
    const float32x4_t vd = vaddq_f32(vy, vone);

    $if DIV_ALGO == "div":
      float32x4_t vf = vdivq_f32(vy, vd);
    $else:
      float32x4_t vr = vrecpeq_f32(vd);
      $if DIV_ALGO == "nr2fma":
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));
      $else:
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $if DIV_ALGO == "nr2recps":
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $else:
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));

      float32x4_t vf = vmulq_f32(vy, vr);
    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcagtq_f32(vx, vdenorm_cutoff)));
    const uint32x4_t vm = vcltq_f32(vx, vmovq_n_f32(0.0f));
    vf = vbslq_f32(vm, vf, vsubq_f32(vone, vf));

    vst1q_f32(y, vf); y += 4;
  }
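  // Handle the last 1-3 elements: a full vector is loaded (permitted by XNN_OOB_READS), but only
  // the valid lanes are stored below.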
  if XNN_UNLIKELY(n != 0) {
    const float32x4_t vx = vld1q_f32(x);

    const float32x4_t vz = vabsq_f32(vx);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vz, vminus_log2e);
    const int32x4_t ve = vshlq_n_s32(vreinterpretq_s32_f32(vn), 17);

    const uint64x2_t vidx = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn), vindex_mask));
    const uint64_t vidx_lo = vgetq_lane_u64(vidx, 0);
    const uint64_t vidx_hi = vgetq_lane_u64(vidx, 1);
    float32x2_t vl_lo = vld1_dup_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) vidx_lo]);
    float32x2_t vl_hi = vld1_dup_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) vidx_hi]);
    vl_lo = vld1_lane_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) (vidx_lo >> 32)], vl_lo, 1);
    vl_hi = vld1_lane_f32(&xnn_table_exp2minus_k_over_64[(uint32_t) (vidx_hi >> 32)], vl_hi, 1);
    const float32x4_t vl = vcombine_f32(vl_lo, vl_hi);

    const float32x4_t vs = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl), ve));
    vn = vsubq_f32(vn, vmagic_bias);
    $if FMA:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vz, vn, vln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vln2_lo);

    float32x4_t vp = vmulq_f32(vt, vc2);
    vp = ${VMULSUBQ_F32}(vt, vp, vt);

    const float32x4_t vy = ${VMULSUBQ_F32}(vs, vs, vp);
    const float32x4_t vd = vaddq_f32(vy, vone);

    $if DIV_ALGO == "div":
      float32x4_t vf = vdivq_f32(vy, vd);
    $else:
      float32x4_t vr = vrecpeq_f32(vd);
      $if DIV_ALGO == "nr2fma":
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));
      $else:
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $if DIV_ALGO == "nr2recps":
        vr = vmulq_f32(vr, vrecpsq_f32(vr, vd));
      $else:
        vr = vfmaq_f32(vr, vr, vfmsq_f32(vone, vr, vd));

      float32x4_t vf = vmulq_f32(vy, vr);
    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcagtq_f32(vx, vdenorm_cutoff)));
    const uint32x4_t vm = vcltq_f32(vx, vmovq_n_f32(0.0f));
    vf = vbslq_f32(vm, vf, vsubq_f32(vone, vf));

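    // Store the leftover elements: a pair first if 2 or 3 remain, then a single lane if the count is odd.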
    float32x2_t vf_lo = vget_low_f32(vf);
    if (n & (2 * sizeof(float))) {
      vst1_f32(y, vf_lo); y += 2;
      vf_lo = vget_high_f32(vf);
    }
    if (n & (1 * sizeof(float))) {
      vst1_lane_f32(y, vf_lo, 0);
    }
  }
}