1 // Auto-generated file. Do not edit!
2 // Template: src/f32-sigmoid/neon-frac-p9-p10-nr1recps.c.in
3 // Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9
10 #include <assert.h>
11
12 #include <arm_neon.h>
13
14 #include <xnnpack/common.h>
15 #include <xnnpack/vunary.h>
16
17
// Evaluates the sigmoid approximation on one vector of 4 floats:
//
//   sigmoid(x) ~= 0.5 + x * P(x^2) / Q(x^2)
//
// P is a degree-4 polynomial in x^2 (odd, degree 9 in x) and Q is a degree-5
// polynomial in x^2 (even, degree 10 in x); both are evaluated with Horner's
// scheme using non-accumulating-order vmlaq_f32 (acc + a * b).
// Inputs are clamped to [-18, 18] first. The division is replaced by a
// reciprocal estimate (vrecpeq_f32) refined by a single Newton-Raphson step
// (vrecpsq_f32 computes 2 - r * d).
static inline float32x4_t sigmoid_frac_p9_p10_nr1recps_f32x4(float32x4_t vx)
{
  const float32x4_t vhalf = vmovq_n_f32(0.5f);

  // Numerator coefficients, odd powers of x.
  const float32x4_t vp1 = vmovq_n_f32(2.48287947061529e-01);
  const float32x4_t vp3 = vmovq_n_f32(8.51377133304701e-03);
  const float32x4_t vp5 = vmovq_n_f32(6.08574864600143e-05);
  const float32x4_t vp7 = vmovq_n_f32(1.15627324459942e-07);
  const float32x4_t vp9 = vmovq_n_f32(4.37031012579801e-11);

  // Denominator coefficients, even powers of x.
  const float32x4_t vq0 = vmovq_n_f32(9.93151921023180e-01);
  const float32x4_t vq2 = vmovq_n_f32(1.16817656904453e-01);
  const float32x4_t vq4 = vmovq_n_f32(1.70198817374094e-03);
  const float32x4_t vq6 = vmovq_n_f32(6.29106785017040e-06);
  const float32x4_t vq8 = vmovq_n_f32(5.76102136993427e-09);
  const float32x4_t vq10 = vmovq_n_f32(6.10247389755681e-13);

  // Sigmoid is effectively saturated beyond |x| = 18; clamping also bounds
  // the polynomial values so they cannot overflow.
  vx = vminq_f32(vx, vdupq_n_f32(18.f));
  vx = vmaxq_f32(vx, vdupq_n_f32(-18.f));

  const float32x4_t vxx = vmulq_f32(vx, vx);

  // x * P(x^2)
  float32x4_t vp = vmlaq_f32(vp7, vxx, vp9);
  vp = vmlaq_f32(vp5, vxx, vp);
  vp = vmlaq_f32(vp3, vxx, vp);
  vp = vmlaq_f32(vp1, vxx, vp);
  vp = vmulq_f32(vx, vp);

  // Q(x^2)
  float32x4_t vq = vmlaq_f32(vq8, vxx, vq10);
  vq = vmlaq_f32(vq6, vxx, vq);
  vq = vmlaq_f32(vq4, vxx, vq);
  vq = vmlaq_f32(vq2, vxx, vq);
  vq = vmlaq_f32(vq0, vxx, vq);

  // 1 / Q: hardware estimate plus one Newton-Raphson refinement.
  float32x4_t vinv_q = vrecpeq_f32(vq);
  vinv_q = vmulq_f32(vinv_q, vrecpsq_f32(vinv_q, vq));

  // 0.5 + (x * P) * (1 / Q)
  return vmlaq_f32(vhalf, vp, vinv_q);
}

// Computes y[i] = sigmoid(x[i]) for every float in the input.
//
// n is the input size in bytes and must be a multiple of sizeof(float).
// Elements are processed 16 at a time, then 4 at a time. A final 1-3 element
// tail is computed from one full 4-float vector load, and only the valid lanes
// are stored.
// NOTE(review): the tail path reads 4 floats from x even when fewer remain, so
// it assumes the input buffer may be over-read past its end (presumably the
// project's padding convention) -- confirm against callers.
// params is not read by this kernel.
void xnn_f32_sigmoid_ukernel__neon_frac_p9_p10_nr1recps_x16(
    size_t n,
    const float* x,
    float* y,
    const void* params)
{
  assert(n % sizeof(float) == 0);

  while (n >= 16 * sizeof(float)) {
    // All 16 inputs are loaded before any output is written, so the result is
    // the same even if x and y overlap.
    const float32x4_t vx_a = vld1q_f32(x);
    const float32x4_t vx_b = vld1q_f32(x + 4);
    const float32x4_t vx_c = vld1q_f32(x + 8);
    const float32x4_t vx_d = vld1q_f32(x + 12);
    x += 16;

    const float32x4_t vy_a = sigmoid_frac_p9_p10_nr1recps_f32x4(vx_a);
    const float32x4_t vy_b = sigmoid_frac_p9_p10_nr1recps_f32x4(vx_b);
    const float32x4_t vy_c = sigmoid_frac_p9_p10_nr1recps_f32x4(vx_c);
    const float32x4_t vy_d = sigmoid_frac_p9_p10_nr1recps_f32x4(vx_d);

    vst1q_f32(y, vy_a);
    vst1q_f32(y + 4, vy_b);
    vst1q_f32(y + 8, vy_c);
    vst1q_f32(y + 12, vy_d);
    y += 16;

    n -= 16 * sizeof(float);
  }

  while (n >= 4 * sizeof(float)) {
    const float32x4_t vy = sigmoid_frac_p9_p10_nr1recps_f32x4(vld1q_f32(x));
    x += 4;
    vst1q_f32(y, vy);
    y += 4;
    n -= 4 * sizeof(float);
  }

  if XNN_UNLIKELY(n != 0) {
    // 1-3 elements left: compute a full vector, then store 2 and/or 1 lanes.
    const float32x4_t vy = sigmoid_frac_p9_p10_nr1recps_f32x4(vld1q_f32(x));

    float32x2_t vy_pair = vget_low_f32(vy);
    if (n & (2 * sizeof(float))) {
      vst1_f32(y, vy_pair);
      y += 2;
      vy_pair = vget_high_f32(vy);
    }
    if (n & (1 * sizeof(float))) {
      vst1_lane_f32(y, vy_pair, 0);
    }
  }
}
220