// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


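// SpMM micro-kernel: multiplies a sparse weight matrix by a dense input
// matrix and stores the clamped result. Each outer iteration covers a block
// of 12 elements along the m dimension (MR = 12); the inner loop covers 2
// output channels at a time (NR = 2) using NEON FMA.
//
// Data layout implied by the loads below:
//  - weights: per channel block, one bias value per channel followed by the
//    packed nonzero weight values for that block.
//  - widx_dmap: per nonzero, the byte difference from the current input
//    position to the input row of the next nonzero.
//  - nidx_nnzmap: number of nonzeros per output channel (block).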
void xnn_f32_spmm_ukernel_12x2__neonfma(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
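  // Main loop: process full blocks of 12 elements along the m dimension.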
  while XNN_LIKELY(i >= 12) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
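    // Blocked channel loop: process 2 output channels per iteration.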
    while (j >= 2) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c0 = vacc0123c0;
      float32x4_t vacc89ABc0 = vacc0123c0;
      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c1 = vacc0123c1;
      float32x4_t vacc89ABc1 = vacc0123c1;
      if XNN_LIKELY(nnz != 0) {
        do {
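          // Each dmap entry is the byte distance from the current input
          // position to the input row of the next nonzero weight; the same
          // indirection pattern repeats in all blocks below.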
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          const float32x4_t va4567 = vld1q_f32(a + 4);
          const float32x4_t va89AB = vld1q_f32(a + 8);
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x2_t vb = vld1_f32(w); w += 2;

          vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
          vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
          vacc89ABc0 = vfmaq_lane_f32(vacc89ABc0, va89AB, vb, 0);
          vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
          vacc89ABc1 = vfmaq_lane_f32(vacc89ABc1, va89AB, vb, 1);
        } while (--nnz != 0);
      }
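      // Clamp accumulators to the [min, max] output range.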
      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
      float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
      float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);

      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
      vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
      vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);

      vst1q_f32(c + 0 * m + 0, vout0123c0);
      vst1q_f32(c + 0 * m + 4, vout4567c0);
      vst1q_f32(c + 0 * m + 8, vout89ABc0);
      vst1q_f32(c + 1 * m + 0, vout0123c1);
      vst1q_f32(c + 1 * m + 4, vout4567c1);
      vst1q_f32(c + 1 * m + 8, vout89ABc1);
      c += 2 * m;
      j -= 2;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(j != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            const float32x4_t va89AB = vld1q_f32(a + 8);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
            vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);

        vst1q_f32(c + 0, vout0123);
        vst1q_f32(c + 4, vout4567);
        vst1q_f32(c + 8, vout89AB);
        c += m;
        j -= 1;
      } while (j != 0);
    }
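    // Rewind c to channel 0's output row, then step to the next 12 elements.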
    c -= m * n;
    c += 12;
    a += 12;
    i -= 12;
  }
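  // Remainder: process the final 1..11 elements of m in 8/4/2/1-wide blocks.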
  if XNN_UNLIKELY(i != 0) {
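    // Process a block of 8 elements (i & 8).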
    if (i & 8) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567c0 = vacc0123c0;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567c1 = vacc0123c1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
            vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
        float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout4567c0 = vmaxq_f32(vout4567c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
        vout4567c1 = vmaxq_f32(vout4567c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 0 * m + 4, vout4567c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        vst1q_f32(c + 1 * m + 4, vout4567c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          float32x4_t vacc4567 = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              const float32x4_t va4567 = vld1q_f32(a + 4);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
              vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);
          vout4567 = vmaxq_f32(vout4567, vmin);

          vst1q_f32(c + 0, vout0123);
          vst1q_f32(c + 4, vout4567);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 8;
      a += 8;
    }
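    // Process a block of 4 elements (i & 4).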
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(c + 0, vout0123);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 4;
      a += 4;
    }
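    // Process a block of 2 elements (i & 2) using 64-bit NEON vectors.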
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va01 = vld1_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
            vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));

        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));

        vst1_f32(c + 0 * m + 0, vout01c0);
        vst1_f32(c + 1 * m + 0, vout01c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va01 = vld1_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, va01, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(c, vout01);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 2;
      a += 2;
    }
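    // Process the final single element (i & 1).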
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va0 = vld1_dup_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
            vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));

        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));

        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va0 = vld1_dup_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, va0, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          // Both lanes of vout0 are identical here, so store lane 0 for
          // consistency with the lane-0 stores in the blocked loop above.
          vst1_lane_f32(c, vout0, 0);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
