• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <arm_neon.h>
9 
10 #include <xnnpack/common.h>
11 #include <xnnpack/vunary.h>
12 
13 
// Evaluates the logistic sigmoid 1 / (1 + exp(-x)) over a buffer of floats
// using a rational approximation:
//
//   sigmoid(x) ~= 1/2 + x * P(x^2) / Q(x^2)
//
// where x * P(x^2) is an odd polynomial of degree 9 and Q(x^2) is an even
// polynomial of degree 10. The quotient is formed by multiplying with the
// reciprocal of Q, obtained from the NEON reciprocal estimate refined by a
// single Newton-Raphson step.
//
// n      - size of the input (and of the output) in BYTES; must be a multiple
//          of 4 * sizeof(float), i.e. a whole number of 4-lane vectors.
// input  - n bytes of float arguments.
// output - n bytes of float results, written in the same order as the inputs.
void xnn_math_f32_sigmoid__neon_frac_p9_p10_nr1recps(
    size_t n,
    const float* input,
    float* output)
{
  assert(n % (4 * sizeof(float)) == 0);

  // Numerator: x * (a1 + a3*x^2 + a5*x^4 + a7*x^6 + a9*x^8).
  const float32x4_t va1 = vdupq_n_f32(2.48287947061529e-01);
  const float32x4_t va3 = vdupq_n_f32(8.51377133304701e-03);
  const float32x4_t va5 = vdupq_n_f32(6.08574864600143e-05);
  const float32x4_t va7 = vdupq_n_f32(1.15627324459942e-07);
  const float32x4_t va9 = vdupq_n_f32(4.37031012579801e-11);

  // Denominator: b0 + b2*x^2 + b4*x^4 + b6*x^6 + b8*x^8 + b10*x^10.
  // Every coefficient is positive, so Q >= b0 > 0 and the reciprocal below
  // is always well defined.
  const float32x4_t vb0 = vdupq_n_f32(9.93151921023180e-01);
  const float32x4_t vb2 = vdupq_n_f32(1.16817656904453e-01);
  const float32x4_t vb4 = vdupq_n_f32(1.70198817374094e-03);
  const float32x4_t vb6 = vdupq_n_f32(6.29106785017040e-06);
  const float32x4_t vb8 = vdupq_n_f32(5.76102136993427e-09);
  const float32x4_t vb10 = vdupq_n_f32(6.10247389755681e-13);

  const float32x4_t vhalf = vdupq_n_f32(0.5f);

  // The rational form is only usable on a bounded interval. The numerator has
  // lower degree than the denominator, so x * P / Q decays toward 0 for large
  // |x| instead of approaching +/-0.5. Inputs are therefore clamped to
  // [-18, 18]; sigmoid(18) = 1 - ~1.5e-8 already rounds to 1.0f.
  const float32x4_t vclamp_hi = vdupq_n_f32(18.f);
  const float32x4_t vclamp_lo = vdupq_n_f32(-18.f);

  const size_t vector_count = n / (4 * sizeof(float));
  for (size_t i = 0; i < vector_count; i++) {
    const float32x4_t vraw = vld1q_f32(input);
    input += 4;

    const float32x4_t vx = vmaxq_f32(vminq_f32(vraw, vclamp_hi), vclamp_lo);
    const float32x4_t vx2 = vmulq_f32(vx, vx);

    // Horner's scheme in x^2 for the numerator, then the extra odd factor x.
    // vmlaq_f32(a, b, c) computes a + b * c lane-wise.
    float32x4_t vp = vmlaq_f32(va7, vx2, va9);
    vp = vmlaq_f32(va5, vx2, vp);
    vp = vmlaq_f32(va3, vx2, vp);
    vp = vmlaq_f32(va1, vx2, vp);
    vp = vmulq_f32(vx, vp);

    // Horner's scheme in x^2 for the denominator.
    float32x4_t vq = vmlaq_f32(vb8, vx2, vb10);
    vq = vmlaq_f32(vb6, vx2, vq);
    vq = vmlaq_f32(vb4, vx2, vq);
    vq = vmlaq_f32(vb2, vx2, vq);
    vq = vmlaq_f32(vb0, vx2, vq);

    // 1/q: hardware estimate, then one Newton-Raphson refinement
    // r <- r * (2 - r*q), where vrecpsq_f32(r, q) supplies (2 - r*q).
    float32x4_t vr = vrecpeq_f32(vq);
    vr = vmulq_f32(vr, vrecpsq_f32(vr, vq));

    // sigmoid(x) ~= 1/2 + p * (1/q)
    vst1q_f32(output, vmlaq_f32(vhalf, vp, vr));
    output += 4;
  }
}
76