• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Auto-generated file. Do not edit!
2 //   Template: src/f32-sigmoid/neon-frac-p9-p10-nr1recps.c.in
3 //   Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9 
10 #include <assert.h>
11 
12 #include <arm_neon.h>
13 
14 #include <xnnpack/common.h>
15 #include <xnnpack/vunary.h>
16 
17 
// Computes y[i] = sigmoid(x[i]) for the n / sizeof(float) elements of x.
//
// sigmoid(x) is approximated by the rational function 0.5 + P(x) / Q(x):
//   P(x) = x * (alpha_1 + alpha_3*x^2 + ... + alpha_9*x^8)   (odd, degree 9)
//   Q(x) = beta_0 + beta_2*x^2 + ... + beta_10*x^10          (even, degree 10)
// Both polynomials are evaluated with Horner's scheme in x^2. The reciprocal
// 1 / Q(x) comes from the VRECPE estimate refined by one Newton-Raphson step
// (VRECPS), hence "nr1recps" in the name.
//
// Elements are processed 16 at a time, then 4 at a time, then a 1-3 element
// tail. The tail is staged through a local buffer, so no load reads past the
// end of x.
//
// Arguments:
//   n      - size of the input in BYTES; must be a multiple of sizeof(float).
//   x      - input array of n / sizeof(float) floats.
//   y      - output array receiving n / sizeof(float) floats.
//   params - unused by this kernel.
void xnn_f32_sigmoid_ukernel__neon_frac_p9_p10_nr1recps_x16(
    size_t n,
    const float* x,
    float* y,
    const void* params)
{
  assert(n % sizeof(float) == 0);
  (void) params;  // This kernel has no tunable parameters.

  const float32x4_t vhalf = vmovq_n_f32(0.5f);

  // The coefficients of the numerator polynomial (odd).
  const float32x4_t valpha_1 = vmovq_n_f32(2.48287947061529e-01);
  const float32x4_t valpha_3 = vmovq_n_f32(8.51377133304701e-03);
  const float32x4_t valpha_5 = vmovq_n_f32(6.08574864600143e-05);
  const float32x4_t valpha_7 = vmovq_n_f32(1.15627324459942e-07);
  const float32x4_t valpha_9 = vmovq_n_f32(4.37031012579801e-11);

  // The coefficients of the denominator polynomial (even).
  const float32x4_t vbeta_0 =  vmovq_n_f32(9.93151921023180e-01);
  const float32x4_t vbeta_2 =  vmovq_n_f32(1.16817656904453e-01);
  const float32x4_t vbeta_4 =  vmovq_n_f32(1.70198817374094e-03);
  const float32x4_t vbeta_6 =  vmovq_n_f32(6.29106785017040e-06);
  const float32x4_t vbeta_8 =  vmovq_n_f32(5.76102136993427e-09);
  const float32x4_t vbeta_10 = vmovq_n_f32(6.10247389755681e-13);

  // Inputs are clamped to [-18, 18] so the high powers of x cannot overflow.
  // Sigmoid ~saturates outside of this range anyway.
  const float32x4_t vsigmoid_maxinput = vdupq_n_f32(18.f);
  const float32x4_t vsigmoid_mininput = vdupq_n_f32(-18.f);

  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    float32x4_t vn0123 = vld1q_f32(x); x += 4;
    float32x4_t vn4567 = vld1q_f32(x); x += 4;
    float32x4_t vn89AB = vld1q_f32(x); x += 4;
    float32x4_t vnCDEF = vld1q_f32(x); x += 4;

    // restrict range to avoid overflow, output saturates outside this anyway
    vn0123 = vminq_f32(vn0123, vsigmoid_maxinput);
    vn0123 = vmaxq_f32(vn0123, vsigmoid_mininput);
    vn4567 = vminq_f32(vn4567, vsigmoid_maxinput);
    vn4567 = vmaxq_f32(vn4567, vsigmoid_mininput);
    vn89AB = vminq_f32(vn89AB, vsigmoid_maxinput);
    vn89AB = vmaxq_f32(vn89AB, vsigmoid_mininput);
    vnCDEF = vminq_f32(vnCDEF, vsigmoid_maxinput);
    vnCDEF = vmaxq_f32(vnCDEF, vsigmoid_mininput);

    // square the input
    const float32x4_t vn0123_sq = vmulq_f32(vn0123, vn0123);
    const float32x4_t vn4567_sq = vmulq_f32(vn4567, vn4567);
    const float32x4_t vn89AB_sq = vmulq_f32(vn89AB, vn89AB);
    const float32x4_t vnCDEF_sq = vmulq_f32(vnCDEF, vnCDEF);

    // Evaluate numerator polynomial (Horner in x^2; vmlaq_f32(a, b, c) = a + b*c)
    float32x4_t vnum0123 = vmlaq_f32(valpha_7, vn0123_sq, valpha_9);
    float32x4_t vnum4567 = vmlaq_f32(valpha_7, vn4567_sq, valpha_9);
    float32x4_t vnum89AB = vmlaq_f32(valpha_7, vn89AB_sq, valpha_9);
    float32x4_t vnumCDEF = vmlaq_f32(valpha_7, vnCDEF_sq, valpha_9);

    vnum0123 = vmlaq_f32(valpha_5, vn0123_sq, vnum0123);
    vnum4567 = vmlaq_f32(valpha_5, vn4567_sq, vnum4567);
    vnum89AB = vmlaq_f32(valpha_5, vn89AB_sq, vnum89AB);
    vnumCDEF = vmlaq_f32(valpha_5, vnCDEF_sq, vnumCDEF);

    vnum0123 = vmlaq_f32(valpha_3, vn0123_sq, vnum0123);
    vnum4567 = vmlaq_f32(valpha_3, vn4567_sq, vnum4567);
    vnum89AB = vmlaq_f32(valpha_3, vn89AB_sq, vnum89AB);
    vnumCDEF = vmlaq_f32(valpha_3, vnCDEF_sq, vnumCDEF);

    vnum0123 = vmlaq_f32(valpha_1, vn0123_sq, vnum0123);
    vnum4567 = vmlaq_f32(valpha_1, vn4567_sq, vnum4567);
    vnum89AB = vmlaq_f32(valpha_1, vn89AB_sq, vnum89AB);
    vnumCDEF = vmlaq_f32(valpha_1, vnCDEF_sq, vnumCDEF);

    // The final multiply by x makes the numerator odd.
    vnum0123 = vmulq_f32(vn0123, vnum0123);
    vnum4567 = vmulq_f32(vn4567, vnum4567);
    vnum89AB = vmulq_f32(vn89AB, vnum89AB);
    vnumCDEF = vmulq_f32(vnCDEF, vnumCDEF);

    // Evaluate denominator polynomial
    float32x4_t vdenom0123 = vmlaq_f32(vbeta_8, vn0123_sq, vbeta_10);
    float32x4_t vdenom4567 = vmlaq_f32(vbeta_8, vn4567_sq, vbeta_10);
    float32x4_t vdenom89AB = vmlaq_f32(vbeta_8, vn89AB_sq, vbeta_10);
    float32x4_t vdenomCDEF = vmlaq_f32(vbeta_8, vnCDEF_sq, vbeta_10);

    vdenom0123 = vmlaq_f32(vbeta_6, vn0123_sq, vdenom0123);
    vdenom4567 = vmlaq_f32(vbeta_6, vn4567_sq, vdenom4567);
    vdenom89AB = vmlaq_f32(vbeta_6, vn89AB_sq, vdenom89AB);
    vdenomCDEF = vmlaq_f32(vbeta_6, vnCDEF_sq, vdenomCDEF);

    vdenom0123 = vmlaq_f32(vbeta_4, vn0123_sq, vdenom0123);
    vdenom4567 = vmlaq_f32(vbeta_4, vn4567_sq, vdenom4567);
    vdenom89AB = vmlaq_f32(vbeta_4, vn89AB_sq, vdenom89AB);
    vdenomCDEF = vmlaq_f32(vbeta_4, vnCDEF_sq, vdenomCDEF);

    vdenom0123 = vmlaq_f32(vbeta_2, vn0123_sq, vdenom0123);
    vdenom4567 = vmlaq_f32(vbeta_2, vn4567_sq, vdenom4567);
    vdenom89AB = vmlaq_f32(vbeta_2, vn89AB_sq, vdenom89AB);
    vdenomCDEF = vmlaq_f32(vbeta_2, vnCDEF_sq, vdenomCDEF);

    vdenom0123 = vmlaq_f32(vbeta_0, vn0123_sq, vdenom0123);
    vdenom4567 = vmlaq_f32(vbeta_0, vn4567_sq, vdenom4567);
    vdenom89AB = vmlaq_f32(vbeta_0, vn89AB_sq, vdenom89AB);
    vdenomCDEF = vmlaq_f32(vbeta_0, vnCDEF_sq, vdenomCDEF);

    // Do division 1. / denom: reciprocal estimate...
    float32x4_t vrecp0123 = vrecpeq_f32(vdenom0123);
    float32x4_t vrecp4567 = vrecpeq_f32(vdenom4567);
    float32x4_t vrecp89AB = vrecpeq_f32(vdenom89AB);
    float32x4_t vrecpCDEF = vrecpeq_f32(vdenomCDEF);

    // ...refined by one Newton-Raphson iteration: r *= (2 - r * denom)
    vrecp0123 = vmulq_f32(vrecp0123, vrecpsq_f32(vrecp0123, vdenom0123));
    vrecp4567 = vmulq_f32(vrecp4567, vrecpsq_f32(vrecp4567, vdenom4567));
    vrecp89AB = vmulq_f32(vrecp89AB, vrecpsq_f32(vrecp89AB, vdenom89AB));
    vrecpCDEF = vmulq_f32(vrecpCDEF, vrecpsq_f32(vrecpCDEF, vdenomCDEF));

    // .5 + num * (1. / denom)
    const float32x4_t vsigmoid0123 = vmlaq_f32(vhalf, vnum0123, vrecp0123);
    const float32x4_t vsigmoid4567 = vmlaq_f32(vhalf, vnum4567, vrecp4567);
    const float32x4_t vsigmoid89AB = vmlaq_f32(vhalf, vnum89AB, vrecp89AB);
    const float32x4_t vsigmoidCDEF = vmlaq_f32(vhalf, vnumCDEF, vrecpCDEF);

    vst1q_f32(y, vsigmoid0123); y += 4;
    vst1q_f32(y, vsigmoid4567); y += 4;
    vst1q_f32(y, vsigmoid89AB); y += 4;
    vst1q_f32(y, vsigmoidCDEF); y += 4;
  }
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    float32x4_t vn0123 = vld1q_f32(x); x += 4;

    vn0123 = vminq_f32(vn0123, vsigmoid_maxinput);
    vn0123 = vmaxq_f32(vn0123, vsigmoid_mininput);

    const float32x4_t vn0123_sq = vmulq_f32(vn0123, vn0123);

    // Evaluate numerator polynomial
    float32x4_t vnum0123 = vmlaq_f32(valpha_7, vn0123_sq, valpha_9);
    vnum0123 = vmlaq_f32(valpha_5, vn0123_sq, vnum0123);
    vnum0123 = vmlaq_f32(valpha_3, vn0123_sq, vnum0123);
    vnum0123 = vmlaq_f32(valpha_1, vn0123_sq, vnum0123);
    vnum0123 = vmulq_f32(vn0123, vnum0123);

    // Evaluate denominator polynomial
    float32x4_t vdenom0123 = vmlaq_f32(vbeta_8, vn0123_sq, vbeta_10);
    vdenom0123 = vmlaq_f32(vbeta_6, vn0123_sq, vdenom0123);
    vdenom0123 = vmlaq_f32(vbeta_4, vn0123_sq, vdenom0123);
    vdenom0123 = vmlaq_f32(vbeta_2, vn0123_sq, vdenom0123);
    vdenom0123 = vmlaq_f32(vbeta_0, vn0123_sq, vdenom0123);

    // Do division, one NR iteration
    float32x4_t vrecp0123 = vrecpeq_f32(vdenom0123);
    vrecp0123 = vmulq_f32(vrecp0123, vrecpsq_f32(vrecp0123, vdenom0123));

    const float32x4_t vsigmoid0123 = vmlaq_f32(vhalf, vnum0123, vrecp0123);

    vst1q_f32(y, vsigmoid0123); y += 4;
  }
  if XNN_UNLIKELY(n != 0) {
    // 1-3 elements remain. Stage them in a zero-padded local buffer so the
    // 4-lane load never reads past the end of x. The padding lanes are
    // computed (sigmoid(0) is well-defined) but never stored.
    float vtail[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
    const size_t remainder = n / sizeof(float);
    for (size_t i = 0; i < remainder; i++) {
      vtail[i] = x[i];
    }
    float32x4_t vn0123 = vld1q_f32(vtail);

    vn0123 = vminq_f32(vn0123, vsigmoid_maxinput);
    vn0123 = vmaxq_f32(vn0123, vsigmoid_mininput);

    const float32x4_t vn0123_sq = vmulq_f32(vn0123, vn0123);

    // Evaluate numerator polynomial
    float32x4_t vnum0123 = vmlaq_f32(valpha_7, vn0123_sq, valpha_9);
    vnum0123 = vmlaq_f32(valpha_5, vn0123_sq, vnum0123);
    vnum0123 = vmlaq_f32(valpha_3, vn0123_sq, vnum0123);
    vnum0123 = vmlaq_f32(valpha_1, vn0123_sq, vnum0123);
    vnum0123 = vmulq_f32(vn0123, vnum0123);

    // Evaluate denominator polynomial
    float32x4_t vdenom0123 = vmlaq_f32(vbeta_8, vn0123_sq, vbeta_10);
    vdenom0123 = vmlaq_f32(vbeta_6, vn0123_sq, vdenom0123);
    vdenom0123 = vmlaq_f32(vbeta_4, vn0123_sq, vdenom0123);
    vdenom0123 = vmlaq_f32(vbeta_2, vn0123_sq, vdenom0123);
    vdenom0123 = vmlaq_f32(vbeta_0, vn0123_sq, vdenom0123);

    // Do division, one NR iteration
    float32x4_t vrecp0123 = vrecpeq_f32(vdenom0123);
    vrecp0123 = vmulq_f32(vrecp0123, vrecpsq_f32(vrecp0123, vdenom0123));

    const float32x4_t vsigmoid0123 = vmlaq_f32(vhalf, vnum0123, vrecp0123);

    // Store exactly `remainder` lanes: first a pair if bit 1 is set, then a
    // single lane if bit 0 is set.
    float32x2_t vf01 = vget_low_f32(vsigmoid0123);
    if (n & (2 * sizeof(float))) {
      vst1_f32(y, vf01); y += 2;
      vf01 = vget_high_f32(vsigmoid0123);
    }
    if (n & (1 * sizeof(float))) {
      vst1_lane_f32(y, vf01, 0);
    }
  }
}
220