// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ELEMENTS_TILE % 4 == 0
$assert ELEMENTS_TILE >= 4
$SIMD_TILE = ELEMENTS_TILE // 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/raddstoreexpminusmax.h>


extern XNN_INTERNAL const float xnn_table_exp2_k_over_64[64];

// Computes f[i] = expf(input[i] - max) for every element, stores each f[i]
// to `output`, and accumulates the sum of all f[i] into *sum ("raddstore" =
// reduce-add + store the per-element results). The approximation uses a
// 64-entry table of 2**(k/64) plus a degree-2 polynomial correction
// ("lut64_p2").
//
// Template parameters:
//   ELEMENTS_TILE - elements processed per main-loop iteration (multiple of 4)
//   ACCUMULATORS  - number of parallel partial-sum vectors in the main loop
//   FMA           - use fused multiply-add (vfmaq_f32) instead of vmlaq_f32
//
// Note: `elements` is a BYTE count and must be a multiple of sizeof(float).
void xnn_f32_raddstoreexpminusmax_ukernel__${"neonfma" if FMA else "neon"}_lut64_p2_x${ELEMENTS_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t elements,
    const float* input,
    float* output,
    float* sum,
    float max) XNN_DISABLE_TSAN
{
  assert(elements % sizeof(float) == 0);

  // "Magic bias": adding 0x1.8p23 forces round-to-nearest-integer into the low
  // mantissa bits; it is subtracted back after the integer bits are extracted.
  const float32x4_t vmagic_bias = vmovq_n_f32(0x1.800000p23f);
  // The smallest x for which expf(x) is normalized.
  const float32x4_t vdenorm_cutoff = vmovq_n_f32(-0x1.5D589Ep6f);
  const float32x4_t vlog2e_x64  = vmovq_n_f32(0x1.715476p6f);
  $if FMA:
    const float32x4_t vminus_ln2_o64_hi = vmovq_n_f32(-0x1.62e43p-7f);
    const float32x4_t vminus_ln2_o64_lo = vmovq_n_f32(0x1.05c61p-35f);
  $else:
    // Last 13 bits are zeroes
    const float32x4_t vminus_ln2_o64_hi = vmovq_n_f32(-0x1.630000p-7f);
    const float32x4_t vminus_ln2_o64_lo = vmovq_n_f32(0x1.BD0106p-19f);

  const float32x4_t vc2 = vmovq_n_f32(0x1.FFFF0Ap-2f);

  const int32x4_t vindex_mask = vmovq_n_s32(INT32_C(0x3F));

  const float32x4_t vi_max = vdupq_n_f32(max);

  $if ELEMENTS_TILE > 4:
    $for K in range(ACCUMULATORS):
      float32x4_t vacc${K} = vmovq_n_f32(0.0f);
    for (; elements >= ${ELEMENTS_TILE} * sizeof(float); elements -= ${ELEMENTS_TILE} * sizeof(float)) {
      // Load ${ELEMENTS_TILE} (${SIMD_TILE}x4) inputs at a time.
      $for N in range(0, ELEMENTS_TILE, 4):
        const float32x4_t vi${ABC[N:N+4]} = vld1q_f32(input); input += 4;

      // Subtract maximum input x := i - i_max. This implies x <= 0.
      $for N in range(0, ELEMENTS_TILE, 4):
        const float32x4_t vx${ABC[N:N+4]} = vsubq_f32(vi${ABC[N:N+4]}, vi_max);

      // Compute reduced argument n := round(x * 64 / log(2)).
      // We do it by adding a large number (magic bias), which causes rounding of the result to an integer, then subtracting
      // the large number back. The first addition is combined with multiplication by log2e into a single FMA instruction.
      // The trick with adding large number is valid only within certain bounds (|x * 64 / log(2)| <= 2**22, i.e.
      // |x| <= 0x1.62E43p+15 = 45426.09375), but that is acceptable, because inputs outside of [-87.336540, 0.0]
      // result in denormalized or underflown expf(x). We fixup the result for such inputs at the very end of the
      // algorithm.
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vn${ABC[N:N+4]} = ${VMULADDQ_F32}(vmagic_bias, vx${ABC[N:N+4]}, vlog2e_x64);

      // Create a floating-point number s (scale) such that s := 2**(n / 64) for such inputs that expf(x) is normalized,
      // i.e. -87.33642 <= x <= 0.0. As n has 6 fractional bits, we split s == 2**(n / 64) = 2**e * 2**(n / 64 - e), where
      // e := int(n / 64). We create s in two steps:
      // 1. Fetch 2**(n / 64 - e) = 2**(n % 64) from the table using the 6 low bits of n, as integer. Note that the
      //    fetched values are in the [1.0, 2.0) range, i.e. their floating-point exponent is 0.
      // 2. Adjust fetched value by addition of e to its floating-point exponent. The result is always a normalized
      //    number, because for -87.33642 <= x <= 0.0 (inputs for which expf(x) is normalized) we have -126 <= e <= 0,
      //    and thus the adjusted exponent is not lower than -126.
      //
      // Extract e from bits 6:14 of n and shift it into bits 23:31 (position of floating-point exponent).
      $for N in range(0, ELEMENTS_TILE, 4):
        const int32x4_t ve${ABC[N:N+4]} = vshlq_n_s32(vbicq_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), vmovq_n_s32(INT32_C(0x3F))), 17);

      // Use bits 0:6 of n, as integer, as an index for table lookup of l := 2**(n % 64).
      // The two 32-bit indices of each 64-bit lane pair are extracted via scalar uint64_t.
      $for N in range(0, ELEMENTS_TILE, 4):
        const uint64x2_t vidx${ABC[N:N+4]} = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), vindex_mask));
        const uint64_t vidx${ABC[N:N+2]} = vgetq_lane_u64(vidx${ABC[N:N+4]}, 0);
        const uint64_t vidx${ABC[N+2:N+4]} = vgetq_lane_u64(vidx${ABC[N:N+4]}, 1);

      $for N in range(0, ELEMENTS_TILE, 4):
        float32x2_t vl${ABC[N:N+2]} = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx${ABC[N:N+2]}]);
        float32x2_t vl${ABC[N+2:N+4]} = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx${ABC[N+2:N+4]}]);

      $for N in range(0, ELEMENTS_TILE, 4):
        vl${ABC[N:N+2]} = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx${ABC[N:N+2]} >> 32)], vl${ABC[N:N+2]}, 1);
        vl${ABC[N+2:N+4]} = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx${ABC[N+2:N+4]} >> 32)], vl${ABC[N+2:N+4]}, 1);
        const float32x4_t vl${ABC[N:N+4]} = vcombine_f32(vl${ABC[N:N+2]}, vl${ABC[N+2:N+4]});

      // Adjust exponent of the value l fetched from the table to get the final s value.
      $for N in range(0, ELEMENTS_TILE, 4):
        const float32x4_t vs${ABC[N:N+4]} = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl${ABC[N:N+4]}), ve${ABC[N:N+4]}));

      // Subtract the large number back to get final n := round(x * 64 / log(2)) as a floating-point number.
      $for N in range(0, ELEMENTS_TILE, 4):
        vn${ABC[N:N+4]} = vsubq_f32(vn${ABC[N:N+4]}, vmagic_bias);

      // Compute reduced argument t := x - n * log(2) / 64.
      // Use Cody-Waite range reduction method (note the two constants representing log(2) / 64) to improve accuracy.
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vx${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2_o64_hi);

      $for N in range(0, ELEMENTS_TILE, 4):
        vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2_o64_lo);

      // Compute degree-2 polynomial approximation for exp(t) on [-log(2)/128, log(2)/128].
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vp${ABC[N:N+4]} = vmulq_f32(vt${ABC[N:N+4]}, vc2);

      $for N in range(0, ELEMENTS_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vt${ABC[N:N+4]}, vp${ABC[N:N+4]});

      // Reconstruct the final f value:
      //   f = s * (1 + t * (1 + t * c2))
      //     = s * (1 + t + t * (t * c2))
      //     = s + s * (t + t * (t * c2))
      //     = s + s * p
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vf${ABC[N:N+4]} = ${VMULADDQ_F32}(vs${ABC[N:N+4]}, vs${ABC[N:N+4]}, vp${ABC[N:N+4]});

      // For inputs below denormal cutoff, replace output with +0.0f.
      // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
      $for N in range(0, ELEMENTS_TILE, 4):
        vf${ABC[N:N+4]} = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf${ABC[N:N+4]}), vcltq_f32(vx${ABC[N:N+4]}, vdenorm_cutoff)));

      // Store ${ELEMENTS_TILE} (${SIMD_TILE}x4) outputs at a time.
      $for N in range(0, ELEMENTS_TILE, 4):
        vst1q_f32(output, vf${ABC[N:N+4]}); output += 4;

      // Accumulate computed exponents, distributed round-robin over the accumulators.
      $for N in range(0, ELEMENTS_TILE, 4):
        vacc${N % ACCUMULATORS} = vaddq_f32(vacc${N % ACCUMULATORS}, vf${ABC[N:N+4]});
    }
    $if ACCUMULATORS > 1:
      // Add up all accumulators to vacc0 (pairwise tree, log2(ACCUMULATORS) rounds).
      $ACC_SLICE = 1
      $while ACC_SLICE < ACCUMULATORS:
        $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
          $if A + ACC_SLICE < ACCUMULATORS:
            vacc${A} = vaddq_f32(vacc${A}, vacc${A + ACC_SLICE});
        $ACC_SLICE *= 2

    float32x4_t vacc = vacc0;
  $else:
    float32x4_t vacc = vmovq_n_f32(0.0f);
  for (; elements >= 4 * sizeof(float); elements -= 4 * sizeof(float)) {
    // Load 4 inputs at a time.
    const float32x4_t vi = vld1q_f32(input); input += 4;

    // Subtract maximum input x := i - i_max. This implies x <= 0.
    const float32x4_t vx = vsubq_f32(vi, vi_max);

    // Compute reduced argument n := round(x * 64 / log(2)).
    // We do it by adding a large number (magic bias), which causes rounding of the result to an integer, then subtracting
    // the large number back. The first addition is combined with multiplication by log2e into a single FMA instruction.
    // The trick with adding large number is valid only within certain bounds (|x * 64 / log(2)| <= 2**22, i.e.
    // |x| <= 0x1.62E43p+15 = 45426.09375), but that is acceptable, because inputs outside of [-87.336540, 0.0]
    // result in denormalized or underflown expf(x). We fixup the result for such inputs at the very end of the
    // algorithm.
    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vx, vlog2e_x64);

    // Create a floating-point number s (scale) such that s := 2**(n / 64) for such inputs that expf(x) is normalized,
    // i.e. -87.33642 <= x <= 0.0. As n has 6 fractional bits, we split s == 2**(n / 64) = 2**e * 2**(n / 64 - e), where
    // e := int(n / 64). We create s in two steps:
    // 1. Fetch 2**(n / 64 - e) = 2**(n % 64) from the table using the 6 low bits of n, as integer. Note that the
    //    fetched values are in the [1.0, 2.0) range, i.e. their floating-point exponent is 0.
    // 2. Adjust fetched value by addition of e to its floating-point exponent. The result is always a normalized
    //    number, because for -87.33642 <= x <= 0.0 (inputs for which expf(x) is normalized) we have -126 <= e <= 0,
    //    and thus the adjusted exponent is not lower than -126.
    //
    // Extract e from bits 6:14 of n and shift it into bits 23:31 (position of floating-point exponent).
    const int32x4_t ve = vshlq_n_s32(vbicq_s32(vreinterpretq_s32_f32(vn), vmovq_n_s32(INT32_C(0x3F))), 17);

    // Use bits 0:6 of n, as integer, as an index for table lookup of l := 2**(n % 64).
    const uint64x2_t vidx = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn), vindex_mask));
    const uint64_t vidx_lo = vgetq_lane_u64(vidx, 0);
    const uint64_t vidx_hi = vgetq_lane_u64(vidx, 1);
    float32x2_t vl_lo = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_lo]);
    float32x2_t vl_hi = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_hi]);
    vl_lo = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_lo >> 32)], vl_lo, 1);
    vl_hi = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_hi >> 32)], vl_hi, 1);
    const float32x4_t vl = vcombine_f32(vl_lo, vl_hi);
    // Adjust exponent of the value l fetched from the table to get the final s value.
    const float32x4_t vs = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl), ve));

    // Subtract the large number back to get final n := round(x * 64 / log(2)) as a floating-point number.
    vn = vsubq_f32(vn, vmagic_bias);

    // Compute reduced argument t := x - n * log(2) / 64.
    // Use Cody-Waite range reduction method (note the two constants representing log(2) / 64) to improve accuracy.
    float32x4_t vt = ${VMULADDQ_F32}(vx, vn, vminus_ln2_o64_hi);
    vt = ${VMULADDQ_F32}(vt, vn, vminus_ln2_o64_lo);

    // Compute degree-2 polynomial approximation for exp(t) on [-log(2)/128, log(2)/128].
    float32x4_t vp = vmulq_f32(vt, vc2);
    vp = ${VMULADDQ_F32}(vt, vt, vp);

    // Reconstruct the final f value:
    //   f = s * (1 + t * (1 + t * c2))
    //     = s * (1 + t + t * (t * c2))
    //     = s + s * (t + t * (t * c2))
    //     = s + s * p
    float32x4_t vf = ${VMULADDQ_F32}(vs, vs, vp);

    // For inputs below denormal cutoff, replace output with +0.0f.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcltq_f32(vx, vdenorm_cutoff)));

    // Store 4 outputs at a time.
    vst1q_f32(output, vf); output += 4;

    // Accumulate computed exponents.
    vacc = vaddq_f32(vacc, vf);
  }
#if XNN_ARCH_ARM64
  // AArch64: collapse the 4-lane partial sum to a scalar right away.
  float vacc_lo = vaddvq_f32(vacc);
#else
  // AArch32: fold to a 2-lane partial sum; final reduction happens at the end.
  float32x2_t vacc_lo = vadd_f32(vget_high_f32(vacc), vget_low_f32(vacc));
#endif
  if (elements != 0) {
    assert(elements >= 1 * sizeof(float));
    assert(elements <= 3 * sizeof(float));
    // Load 4 inputs at a time.
    // NOTE(review): a full 4-element vector is loaded although only 1-3
    // elements remain, i.e. up to 3 floats past the logical end are read;
    // presumably the caller guarantees readable padding - confirm against
    // the microkernel contract.
    const float32x4_t vi = vld1q_f32(input); input += 4;

    // Subtract maximum input x := i - i_max. This implies x <= 0.
    const float32x4_t vx = vsubq_f32(vi, vi_max);

    // Compute reduced argument n := round(x * 64 / log(2)).
    // We do it by adding a large number (magic bias), which causes rounding of the result to an integer, then subtracting
    // the large number back. The first addition is combined with multiplication by log2e into a single FMA instruction.
    // The trick with adding large number is valid only within certain bounds (|x * 64 / log(2)| <= 2**22, i.e.
    // |x| <= 0x1.62E43p+15 = 45426.09375), but that is acceptable, because inputs outside of [-87.336540, 0.0]
    // result in denormalized or underflown expf(x). We fixup the result for such inputs at the very end of the
    // algorithm.
    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vx, vlog2e_x64);

    // Create a floating-point number s (scale) such that s := 2**(n / 64) for such inputs that expf(x) is normalized,
    // i.e. -87.33642 <= x <= 0.0. As n has 6 fractional bits, we split s == 2**(n / 64) = 2**e * 2**(n / 64 - e), where
    // e := int(n / 64). We create s in two steps:
    // 1. Fetch 2**(n / 64 - e) = 2**(n % 64) from the table using the 6 low bits of n, as integer. Note that the
    //    fetched values are in the [1.0, 2.0) range, i.e. their floating-point exponent is 0.
    // 2. Adjust fetched value by addition of e to its floating-point exponent. The result is always a normalized
    //    number, because for -87.33642 <= x <= 0.0 (inputs for which expf(x) is normalized) we have -126 <= e <= 0,
    //    and thus the adjusted exponent is not lower than -126.
    //
    // Extract e from bits 6:14 of n and shift it into bits 23:31 (position of floating-point exponent).
    const int32x4_t ve = vshlq_n_s32(vbicq_s32(vreinterpretq_s32_f32(vn), vmovq_n_s32(INT32_C(0x3F))), 17);

    // Use bits 0:6 of n, as integer, as an index for table lookup of l := 2**(n % 64).
    const uint64x2_t vidx = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn), vindex_mask));
    const uint64_t vidx_lo = vgetq_lane_u64(vidx, 0);
    const uint64_t vidx_hi = vgetq_lane_u64(vidx, 1);
    float32x2_t vl_lo = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_lo]);
    float32x2_t vl_hi = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_hi]);
    vl_lo = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_lo >> 32)], vl_lo, 1);
    vl_hi = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_hi >> 32)], vl_hi, 1);
    const float32x4_t vl = vcombine_f32(vl_lo, vl_hi);
    // Adjust exponent of the value l fetched from the table to get the final s value.
    const float32x4_t vs = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl), ve));

    // Subtract the large number back to get final n := round(x * 64 / log(2)) as a floating-point number.
    vn = vsubq_f32(vn, vmagic_bias);

    // Compute reduced argument t := x - n * log(2) / 64.
    // Use Cody-Waite range reduction method (note the two constants representing log(2) / 64) to improve accuracy.
    float32x4_t vt = ${VMULADDQ_F32}(vx, vn, vminus_ln2_o64_hi);
    vt = ${VMULADDQ_F32}(vt, vn, vminus_ln2_o64_lo);

    // Compute degree-2 polynomial approximation for exp(t) on [-log(2)/128, log(2)/128].
    float32x4_t vp = vmulq_f32(vt, vc2);
    vp = ${VMULADDQ_F32}(vt, vt, vp);

    // Reconstruct the final f value:
    //   f = s * (1 + t * (1 + t * c2))
    //     = s * (1 + t + t * (t * c2))
    //     = s + s * (t + t * (t * c2))
    //     = s + s * p
    float32x4_t vf = ${VMULADDQ_F32}(vs, vs, vp);

    // For inputs below denormal cutoff, replace output with +0.0f.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcltq_f32(vx, vdenorm_cutoff)));

    float32x2_t vf_lo = vget_low_f32(vf);
    if (elements & (2 * sizeof(float))) {
      // Store 2 outputs at a time.
      vst1_f32(output, vf_lo); output += 2;

      // Accumulate 2 computed exponents.
      #if XNN_ARCH_ARM64
        vacc_lo += vaddv_f32(vf_lo);
      #else
        vacc_lo = vadd_f32(vacc_lo, vf_lo);
      #endif

      vf_lo = vget_high_f32(vf);
    }
    if (elements & (1 * sizeof(float))) {
      // Store 1 output at a time.
      vst1_lane_f32(output, vf_lo, 0);

      // Accumulate 1 computed exponent.
      #if XNN_ARCH_ARM64
        vacc_lo += vget_lane_f32(vf_lo, 0);
      #else
        // Shift the remaining value into the upper lane (lower lane becomes
        // +0.0f) so one vector add accumulates it without a lane extract.
        vacc_lo = vadd_f32(vacc_lo, vreinterpret_f32_u64(vshl_n_u64(vreinterpret_u64_f32(vf_lo), 32)));
      #endif
    }
  }
  // Final reduction: scalar already on AArch64; pairwise-add the 2 remaining
  // lanes on AArch32 and store lane 0.
#if XNN_ARCH_ARM64
  *sum = vacc_lo;
#else
  vst1_lane_f32(sum, vpadd_f32(vacc_lo, vacc_lo), 0);
#endif
}
326