1 // Auto-generated file. Do not edit!
2 // Template: src/f32-spmm/neon-blocked.c.in
3 // Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9
10 #include <assert.h>
11
12 #include <arm_neon.h>
13
14 #include <xnnpack/spmm.h>
15
16
xnn_f32_spmm_ukernel_8x4__neonfma(uint32_t m,uint32_t n,const float * restrict a,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict c,const union xnn_f32_output_params params[restrict static1])17 void xnn_f32_spmm_ukernel_8x4__neonfma(
18 uint32_t m,
19 uint32_t n,
20 const float*restrict a,
21 const float*restrict weights,
22 const int32_t*restrict widx_dmap,
23 const uint32_t*restrict nidx_nnzmap,
24 float*restrict c,
25 const union xnn_f32_output_params params[restrict static 1])
26 {
27 assert(m != 0);
28
29 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
30 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
31 size_t i = m;
32 while XNN_LIKELY(i >= 8) {
33 const float*restrict w = weights;
34 const int32_t* dmap = widx_dmap;
35 const uint32_t* nnzmap = nidx_nnzmap;
36 size_t j = n;
37 while (j >= 4) {
38 uint32_t nnz = *nnzmap++;
39 float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
40 float32x4_t vacc4567c0 = vacc0123c0;
41 float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
42 float32x4_t vacc4567c1 = vacc0123c1;
43 float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
44 float32x4_t vacc4567c2 = vacc0123c2;
45 float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
46 float32x4_t vacc4567c3 = vacc0123c3;
47 if XNN_LIKELY(nnz != 0) {
48 do {
49 const intptr_t diff = *dmap++;
50 const float32x4_t va0123 = vld1q_f32(a);
51 const float32x4_t va4567 = vld1q_f32(a + 4);
52 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
53 const float32x4_t vb = vld1q_f32(w); w += 4;
54
55 vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
56 vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
57 vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
58 vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
59 vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
60 vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
61 vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
62 vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
63 } while (--nnz != 0);
64 }
65 float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
66 float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
67 float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
68 float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
69 float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
70 float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
71 float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
72 float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);
73
74 vout0123c0 = vmaxq_f32(vout0123c0, vmin);
75 vout4567c0 = vmaxq_f32(vout4567c0, vmin);
76 vout0123c1 = vmaxq_f32(vout0123c1, vmin);
77 vout4567c1 = vmaxq_f32(vout4567c1, vmin);
78 vout0123c2 = vmaxq_f32(vout0123c2, vmin);
79 vout4567c2 = vmaxq_f32(vout4567c2, vmin);
80 vout0123c3 = vmaxq_f32(vout0123c3, vmin);
81 vout4567c3 = vmaxq_f32(vout4567c3, vmin);
82
83 vst1q_f32(c + 0 * m + 0, vout0123c0);
84 vst1q_f32(c + 0 * m + 4, vout4567c0);
85 vst1q_f32(c + 1 * m + 0, vout0123c1);
86 vst1q_f32(c + 1 * m + 4, vout4567c1);
87 vst1q_f32(c + 2 * m + 0, vout0123c2);
88 vst1q_f32(c + 2 * m + 4, vout4567c2);
89 vst1q_f32(c + 3 * m + 0, vout0123c3);
90 vst1q_f32(c + 3 * m + 4, vout4567c3);
91 c += 4 * m;
92 j -= 4;
93 }
94
95 // clean up loop, fall back to nr=1
96 if XNN_UNLIKELY(j != 0) {
97 do {
98 uint32_t nnz = *nnzmap++;
99 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
100 float32x4_t vacc4567 = vacc0123;
101 if XNN_LIKELY(nnz != 0) {
102 do {
103 const intptr_t diff = *dmap++;
104 const float32x4_t va0123 = vld1q_f32(a);
105 const float32x4_t va4567 = vld1q_f32(a + 4);
106 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
107 const float32x4_t vb = vld1q_dup_f32(w); w += 1;
108 vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
109 vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
110 } while (--nnz != 0);
111 }
112 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
113 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
114
115 vout0123 = vmaxq_f32(vout0123, vmin);
116 vout4567 = vmaxq_f32(vout4567, vmin);
117
118 vst1q_f32(c + 0, vout0123);
119 vst1q_f32(c + 4, vout4567);
120 c += m;
121 j -= 1;
122 } while (j != 0);
123 }
124 c -= m * n;
125 c += 8;
126 a += 8;
127 i -= 8;
128 }
129 if XNN_UNLIKELY(i != 0) {
130 if (i & 4) {
131 const float*restrict w = weights;
132 const int32_t* dmap = widx_dmap;
133 const uint32_t* nnzmap = nidx_nnzmap;
134 size_t j = n;
135 while (j >= 4) {
136 uint32_t nnz = *nnzmap++;
137 float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
138 float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
139 float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
140 float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
141 if XNN_LIKELY(nnz != 0) {
142 do {
143 const intptr_t diff = *dmap++;
144 const float32x4_t va0123 = vld1q_f32(a);
145 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
146 const float32x4_t vb = vld1q_f32(w); w += 4;
147
148 vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
149 vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
150 vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
151 vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
152 } while (--nnz != 0);
153 }
154 float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
155 float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
156 float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
157 float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
158
159 vout0123c0 = vmaxq_f32(vout0123c0, vmin);
160 vout0123c1 = vmaxq_f32(vout0123c1, vmin);
161 vout0123c2 = vmaxq_f32(vout0123c2, vmin);
162 vout0123c3 = vmaxq_f32(vout0123c3, vmin);
163
164 vst1q_f32(c + 0 * m + 0, vout0123c0);
165 vst1q_f32(c + 1 * m + 0, vout0123c1);
166 vst1q_f32(c + 2 * m + 0, vout0123c2);
167 vst1q_f32(c + 3 * m + 0, vout0123c3);
168 c += 4 * m;
169 j -= 4;
170 }
171
172 // clean up loop, fall back to nr=1
173 if XNN_UNLIKELY(j != 0) {
174 do {
175 uint32_t nnz = *nnzmap++;
176 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
177 if XNN_LIKELY(nnz != 0) {
178 do {
179 const intptr_t diff = *dmap++;
180 const float32x4_t va0123 = vld1q_f32(a);
181 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
182 const float32x4_t vb = vld1q_dup_f32(w); w += 1;
183 vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
184 } while (--nnz != 0);
185 }
186 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
187
188 vout0123 = vmaxq_f32(vout0123, vmin);
189
190 vst1q_f32(c + 0, vout0123);
191 c += m;
192 j -= 1;
193 } while (j != 0);
194 }
195 c -= m * n;
196 c += 4;
197 a += 4;
198 }
199 if (i & 2) {
200 const float*restrict w = weights;
201 const int32_t* dmap = widx_dmap;
202 const uint32_t* nnzmap = nidx_nnzmap;
203 size_t j = n;
204 while (j >= 4) {
205 uint32_t nnz = *nnzmap++;
206 float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
207 float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
208 float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1;
209 float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1;
210 if XNN_LIKELY(nnz != 0) {
211 do {
212 const intptr_t diff = *dmap++;
213 const float32x2_t va01 = vld1_f32(a);
214 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
215 const float32x4_t vb = vld1q_f32(w); w += 4;
216
217 vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0);
218 vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1);
219 vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2);
220 vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3);
221 } while (--nnz != 0);
222 }
223 float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
224 float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
225 float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax));
226 float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax));
227
228 vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
229 vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
230 vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin));
231 vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin));
232
233 vst1_f32(c + 0 * m + 0, vout01c0);
234 vst1_f32(c + 1 * m + 0, vout01c1);
235 vst1_f32(c + 2 * m + 0, vout01c2);
236 vst1_f32(c + 3 * m + 0, vout01c3);
237 c += 4 * m;
238 j -= 4;
239 }
240
241 // clean up loop, fall back to nr=1
242 if XNN_UNLIKELY(j != 0) {
243 do {
244 uint32_t nnz = *nnzmap++;
245 float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
246 if XNN_LIKELY(nnz != 0) {
247 do {
248 const intptr_t diff = *dmap++;
249 const float32x2_t va01 = vld1_f32(a);
250 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
251 const float32x2_t vb = vld1_dup_f32(w); w += 1;
252 vacc01 = vfma_f32(vacc01, va01, vb);
253 } while (--nnz != 0);
254 }
255 float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
256 vout01 = vmax_f32(vout01, vget_low_f32(vmin));
257
258 vst1_f32(c, vout01);
259 c += m;
260 j -= 1;
261 } while (j != 0);
262 }
263 c -= m * n;
264 c += 2;
265 a += 2;
266 }
267 if (i & 1) {
268 const float*restrict w = weights;
269 const int32_t* dmap = widx_dmap;
270 const uint32_t* nnzmap = nidx_nnzmap;
271 size_t j = n;
272 while (j >= 4) {
273 uint32_t nnz = *nnzmap++;
274 float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
275 float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
276 float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1;
277 float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1;
278 if XNN_LIKELY(nnz != 0) {
279 do {
280 const intptr_t diff = *dmap++;
281 const float32x2_t va0 = vld1_dup_f32(a);
282 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
283 const float32x4_t vb = vld1q_f32(w); w += 4;
284
285 vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0);
286 vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1);
287 vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2);
288 vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3);
289 } while (--nnz != 0);
290 }
291 float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
292 float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
293 float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax));
294 float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax));
295
296 vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
297 vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
298 vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin));
299 vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin));
300
301 vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
302 vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
303 vst1_lane_f32(c + 2 * m + 0, vout0c2, 0);
304 vst1_lane_f32(c + 3 * m + 0, vout0c3, 0);
305 c += 4 * m;
306 j -= 4;
307 }
308
309 // clean up loop, fall back to nr=1
310 if XNN_UNLIKELY(j != 0) {
311 do {
312 uint32_t nnz = *nnzmap++;
313 float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
314 if XNN_LIKELY(nnz != 0) {
315 do {
316 const intptr_t diff = *dmap++;
317 const float32x2_t va0 = vld1_dup_f32(a);
318 a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
319 const float32x2_t vb = vld1_dup_f32(w); w += 1;
320 vacc0 = vfma_f32(vacc0, va0, vb);
321 } while (--nnz != 0);
322 }
323 float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
324 vout0 = vmax_f32(vout0, vget_low_f32(vmin));
325
326 vst1_lane_f32(c, vout0, 1);
327 c += m;
328 j -= 1;
329 } while (j != 0);
330 }
331 c -= m * n;
332 c += 1;
333 a += 1;
334 }
335 }
336 }
337