// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

xnn_f32_spmm_minmax_ukernel_16x2__neonfma(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_16x2__neonfma(
18 size_t mc,
19 size_t nc,
20 const float*restrict input,
21 const float*restrict weights,
22 const int32_t*restrict widx_dmap,
23 const uint32_t*restrict nidx_nnzmap,
24 float*restrict output,
25 size_t output_stride,
26 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28 assert(mc != 0);
29 assert(mc % sizeof(float) == 0);
30 assert(nc != 0);
31
32 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
33 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
34 size_t output_decrement = output_stride * nc - 16 * sizeof(float);
35 while XNN_LIKELY(mc >= 16 * sizeof(float)) {
36 const float*restrict w = weights;
37 const int32_t* dmap = widx_dmap;
38 const uint32_t* nnzmap = nidx_nnzmap;
39 size_t n = nc;
40 while (n >= 2) {
41 uint32_t nnz = *nnzmap++;
42 float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
43 float32x4_t vacc4567n0 = vacc0123n0;
44 float32x4_t vacc89ABn0 = vacc0123n0;
45 float32x4_t vaccCDEFn0 = vacc0123n0;
46 float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
47 float32x4_t vacc4567n1 = vacc0123n1;
48 float32x4_t vacc89ABn1 = vacc0123n1;
49 float32x4_t vaccCDEFn1 = vacc0123n1;
50 if XNN_LIKELY(nnz != 0) {
51 do {
52 const intptr_t diff = *dmap++;
53 const float32x4_t vi0123 = vld1q_f32(input);
54 const float32x4_t vi4567 = vld1q_f32(input + 4);
55 const float32x4_t vi89AB = vld1q_f32(input + 8);
56 const float32x4_t viCDEF = vld1q_f32(input + 12);
57 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
58 __builtin_prefetch(input + 16);
59 const float32x2_t vw = vld1_f32(w); w += 2;
60 __builtin_prefetch(w + 32);
61 vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
62 vacc4567n0 = vfmaq_lane_f32(vacc4567n0, vi4567, vw, 0);
63 vacc89ABn0 = vfmaq_lane_f32(vacc89ABn0, vi89AB, vw, 0);
64 vaccCDEFn0 = vfmaq_lane_f32(vaccCDEFn0, viCDEF, vw, 0);
65 vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
66 vacc4567n1 = vfmaq_lane_f32(vacc4567n1, vi4567, vw, 1);
67 vacc89ABn1 = vfmaq_lane_f32(vacc89ABn1, vi89AB, vw, 1);
68 vaccCDEFn1 = vfmaq_lane_f32(vaccCDEFn1, viCDEF, vw, 1);
69 } while (--nnz != 0);
70 }
71 float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
72 float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
73 float32x4_t vout89ABn0 = vminq_f32(vacc89ABn0, vmax);
74 float32x4_t voutCDEFn0 = vminq_f32(vaccCDEFn0, vmax);
75 float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
76 float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
77 float32x4_t vout89ABn1 = vminq_f32(vacc89ABn1, vmax);
78 float32x4_t voutCDEFn1 = vminq_f32(vaccCDEFn1, vmax);
79
80 vout0123n0 = vmaxq_f32(vout0123n0, vmin);
81 vout4567n0 = vmaxq_f32(vout4567n0, vmin);
82 vout89ABn0 = vmaxq_f32(vout89ABn0, vmin);
83 voutCDEFn0 = vmaxq_f32(voutCDEFn0, vmin);
84 vout0123n1 = vmaxq_f32(vout0123n1, vmin);
85 vout4567n1 = vmaxq_f32(vout4567n1, vmin);
86 vout89ABn1 = vmaxq_f32(vout89ABn1, vmin);
87 voutCDEFn1 = vmaxq_f32(voutCDEFn1, vmin);
88
89 vst1q_f32(output + 0, vout0123n0);
90 vst1q_f32(output + 4, vout4567n0);
91 vst1q_f32(output + 8, vout89ABn0);
92 vst1q_f32(output + 12, voutCDEFn0);
93 output = (float*restrict) ((uintptr_t) output + output_stride);
94 vst1q_f32(output + 0, vout0123n1);
95 vst1q_f32(output + 4, vout4567n1);
96 vst1q_f32(output + 8, vout89ABn1);
97 vst1q_f32(output + 12, voutCDEFn1);
98 output = (float*restrict) ((uintptr_t) output + output_stride);
99 n -= 2;
100 }
101
102 // clean up loop, fall back to nr=1
103 if XNN_UNLIKELY(n != 0) {
104 do {
105 uint32_t nnz = *nnzmap++;
106 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
107 float32x4_t vacc4567 = vacc0123;
108 float32x4_t vacc89AB = vacc0123;
109 float32x4_t vaccCDEF = vacc0123;
110 if XNN_LIKELY(nnz != 0) {
111 do {
112 const intptr_t diff = *dmap++;
113 const float32x4_t vi0123 = vld1q_f32(input);
114 const float32x4_t vi4567 = vld1q_f32(input + 4);
115 const float32x4_t vi89AB = vld1q_f32(input + 8);
116 const float32x4_t viCDEF = vld1q_f32(input + 12);
117 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
118 __builtin_prefetch(input + 16);
119 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
120 __builtin_prefetch(w + 32);
121 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
122 vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
123 vacc89AB = vfmaq_f32(vacc89AB, vi89AB, vw);
124 vaccCDEF = vfmaq_f32(vaccCDEF, viCDEF, vw);
125 } while (--nnz != 0);
126 }
127 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
128 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
129 float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
130 float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
131
132 vout0123 = vmaxq_f32(vout0123, vmin);
133 vout4567 = vmaxq_f32(vout4567, vmin);
134 vout89AB = vmaxq_f32(vout89AB, vmin);
135 voutCDEF = vmaxq_f32(voutCDEF, vmin);
136
137 vst1q_f32(output + 0, vout0123);
138 vst1q_f32(output + 4, vout4567);
139 vst1q_f32(output + 8, vout89AB);
140 vst1q_f32(output + 12, voutCDEF);
141 output = (float*restrict) ((uintptr_t) output + output_stride);
142 n -= 1;
143 } while (n != 0);
144 }
145 output = (float*restrict) ((uintptr_t) output - output_decrement);
146 input += 16;
147 mc -= 16 * sizeof(float);
148 }
149 if XNN_UNLIKELY(mc != 0) {
150 output_decrement += 8 * sizeof(float);
151 if (mc & (8 * sizeof(float))) {
152 const float*restrict w = weights;
153 const int32_t* dmap = widx_dmap;
154 const uint32_t* nnzmap = nidx_nnzmap;
155 size_t n = nc;
156 while (n >= 2) {
157 uint32_t nnz = *nnzmap++;
158 float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
159 float32x4_t vacc4567n0 = vacc0123n0;
160 float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
161 float32x4_t vacc4567n1 = vacc0123n1;
162 if XNN_LIKELY(nnz != 0) {
163 do {
164 const intptr_t diff = *dmap++;
165 const float32x4_t vi0123 = vld1q_f32(input);
166 const float32x4_t vi4567 = vld1q_f32(input + 4);
167 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
168 const float32x2_t vw = vld1_f32(w); w += 2;
169
170 vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
171 vacc4567n0 = vfmaq_lane_f32(vacc4567n0, vi4567, vw, 0);
172 vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
173 vacc4567n1 = vfmaq_lane_f32(vacc4567n1, vi4567, vw, 1);
174 } while (--nnz != 0);
175 }
176 float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
177 float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
178 float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
179 float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
180
181 vout0123n0 = vmaxq_f32(vout0123n0, vmin);
182 vout4567n0 = vmaxq_f32(vout4567n0, vmin);
183 vout0123n1 = vmaxq_f32(vout0123n1, vmin);
184 vout4567n1 = vmaxq_f32(vout4567n1, vmin);
185
186 vst1q_f32(output + 0, vout0123n0);
187 vst1q_f32(output + 4, vout4567n0);
188 output = (float*restrict) ((uintptr_t) output + output_stride);
189 vst1q_f32(output + 0, vout0123n1);
190 vst1q_f32(output + 4, vout4567n1);
191 output = (float*restrict) ((uintptr_t) output + output_stride);
192 n -= 2;
193 }
194
195 // clean up loop, fall back to nr=1
196 if XNN_UNLIKELY(n != 0) {
197 do {
198 uint32_t nnz = *nnzmap++;
199 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
200 float32x4_t vacc4567 = vacc0123;
201 if XNN_LIKELY(nnz != 0) {
202 do {
203 const intptr_t diff = *dmap++;
204 const float32x4_t vi0123 = vld1q_f32(input);
205 const float32x4_t vi4567 = vld1q_f32(input + 4);
206 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
207 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
208 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
209 vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
210 } while (--nnz != 0);
211 }
212 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
213 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
214
215 vout0123 = vmaxq_f32(vout0123, vmin);
216 vout4567 = vmaxq_f32(vout4567, vmin);
217
218 vst1q_f32(output + 0, vout0123);
219 vst1q_f32(output + 4, vout4567);
220 output = (float*restrict) ((uintptr_t) output + output_stride);
221 n -= 1;
222 } while (n != 0);
223 }
224 output = (float*restrict) ((uintptr_t) output - output_decrement);
225 input += 8;
226 }
227 output_decrement += 4 * sizeof(float);
228 if (mc & (4 * sizeof(float))) {
229 const float*restrict w = weights;
230 const int32_t* dmap = widx_dmap;
231 const uint32_t* nnzmap = nidx_nnzmap;
232 size_t n = nc;
233 while (n >= 2) {
234 uint32_t nnz = *nnzmap++;
235 float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
236 float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
237 if XNN_LIKELY(nnz != 0) {
238 do {
239 const intptr_t diff = *dmap++;
240 const float32x4_t vi0123 = vld1q_f32(input);
241 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
242 const float32x2_t vw = vld1_f32(w); w += 2;
243
244 vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
245 vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
246 } while (--nnz != 0);
247 }
248 float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
249 float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
250
251 vout0123n0 = vmaxq_f32(vout0123n0, vmin);
252 vout0123n1 = vmaxq_f32(vout0123n1, vmin);
253
254 vst1q_f32(output + 0, vout0123n0);
255 output = (float*restrict) ((uintptr_t) output + output_stride);
256 vst1q_f32(output + 0, vout0123n1);
257 output = (float*restrict) ((uintptr_t) output + output_stride);
258 n -= 2;
259 }
260
261 // clean up loop, fall back to nr=1
262 if XNN_UNLIKELY(n != 0) {
263 do {
264 uint32_t nnz = *nnzmap++;
265 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
266 if XNN_LIKELY(nnz != 0) {
267 do {
268 const intptr_t diff = *dmap++;
269 const float32x4_t vi0123 = vld1q_f32(input);
270 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
271 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
272 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
273 } while (--nnz != 0);
274 }
275 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
276
277 vout0123 = vmaxq_f32(vout0123, vmin);
278
279 vst1q_f32(output + 0, vout0123);
280 output = (float*restrict) ((uintptr_t) output + output_stride);
281 n -= 1;
282 } while (n != 0);
283 }
284 output = (float*restrict) ((uintptr_t) output - output_decrement);
285 input += 4;
286 }
287 output_decrement += 2 * sizeof(float);
288 if (mc & (2 * sizeof(float))) {
289 const float*restrict w = weights;
290 const int32_t* dmap = widx_dmap;
291 const uint32_t* nnzmap = nidx_nnzmap;
292 size_t n = nc;
293 while (n >= 2) {
294 uint32_t nnz = *nnzmap++;
295 float32x2_t vacc01n0 = vld1_dup_f32(w); w += 1;
296 float32x2_t vacc01n1 = vld1_dup_f32(w); w += 1;
297 if XNN_LIKELY(nnz != 0) {
298 do {
299 const intptr_t diff = *dmap++;
300 const float32x2_t vi01 = vld1_f32(input);
301 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
302 const float32x2_t vw = vld1_f32(w); w += 2;
303
304 vacc01n0 = vfma_lane_f32(vacc01n0, vi01, vw, 0);
305 vacc01n1 = vfma_lane_f32(vacc01n1, vi01, vw, 1);
306 } while (--nnz != 0);
307 }
308 float32x2_t vout01n0 = vmin_f32(vacc01n0, vget_low_f32(vmax));
309 float32x2_t vout01n1 = vmin_f32(vacc01n1, vget_low_f32(vmax));
310
311 vout01n0 = vmax_f32(vout01n0, vget_low_f32(vmin));
312 vout01n1 = vmax_f32(vout01n1, vget_low_f32(vmin));
313
314 vst1_f32(output + 0, vout01n0);
315 output = (float*restrict) ((uintptr_t) output + output_stride);
316 vst1_f32(output + 0, vout01n1);
317 output = (float*restrict) ((uintptr_t) output + output_stride);
318 n -= 2;
319 }
320
321 // clean up loop, fall back to nr=1
322 if XNN_UNLIKELY(n != 0) {
323 do {
324 uint32_t nnz = *nnzmap++;
325 float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
326 if XNN_LIKELY(nnz != 0) {
327 do {
328 const intptr_t diff = *dmap++;
329 const float32x2_t vi01 = vld1_f32(input);
330 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
331 const float32x2_t vw = vld1_dup_f32(w); w += 1;
332 vacc01 = vfma_f32(vacc01, vi01, vw);
333 } while (--nnz != 0);
334 }
335 float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
336 vout01 = vmax_f32(vout01, vget_low_f32(vmin));
337
338 vst1_f32(output, vout01);
339 output = (float*restrict) ((uintptr_t) output + output_stride);
340 n -= 1;
341 } while (n != 0);
342 }
343 output = (float*restrict) ((uintptr_t) output - output_decrement);
344 input += 2;
345 }
346 output_decrement += 1 * sizeof(float);
347 if (mc & (1 * sizeof(float))) {
348 const float*restrict w = weights;
349 const int32_t* dmap = widx_dmap;
350 const uint32_t* nnzmap = nidx_nnzmap;
351 size_t n = nc;
352 while (n >= 2) {
353 uint32_t nnz = *nnzmap++;
354 float32x2_t vacc0n0 = vld1_dup_f32(w); w += 1;
355 float32x2_t vacc0n1 = vld1_dup_f32(w); w += 1;
356 if XNN_LIKELY(nnz != 0) {
357 do {
358 const intptr_t diff = *dmap++;
359 const float32x2_t vi0 = vld1_dup_f32(input);
360 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
361 const float32x2_t vw = vld1_f32(w); w += 2;
362
363 vacc0n0 = vfma_lane_f32(vacc0n0, vi0, vw, 0);
364 vacc0n1 = vfma_lane_f32(vacc0n1, vi0, vw, 1);
365 } while (--nnz != 0);
366 }
367 float32x2_t vout0n0 = vmin_f32(vacc0n0, vget_low_f32(vmax));
368 float32x2_t vout0n1 = vmin_f32(vacc0n1, vget_low_f32(vmax));
369
370 vout0n0 = vmax_f32(vout0n0, vget_low_f32(vmin));
371 vout0n1 = vmax_f32(vout0n1, vget_low_f32(vmin));
372
373 vst1_lane_f32(output + 0, vout0n0, 0);
374 output = (float*restrict) ((uintptr_t) output + output_stride);
375 vst1_lane_f32(output + 0, vout0n1, 0);
376 output = (float*restrict) ((uintptr_t) output + output_stride);
377 n -= 2;
378 }
379
380 // clean up loop, fall back to nr=1
381 if XNN_UNLIKELY(n != 0) {
382 do {
383 uint32_t nnz = *nnzmap++;
384 float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
385 if XNN_LIKELY(nnz != 0) {
386 do {
387 const intptr_t diff = *dmap++;
388 const float32x2_t vi0 = vld1_dup_f32(input);
389 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
390 const float32x2_t vw = vld1_dup_f32(w); w += 1;
391 vacc0 = vfma_f32(vacc0, vi0, vw);
392 } while (--nnz != 0);
393 }
394 float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
395 vout0 = vmax_f32(vout0, vget_low_f32(vmin));
396
397 vst1_lane_f32(output, vout0, 1);
398 output = (float*restrict) ((uintptr_t) output + output_stride);
399 n -= 1;
400 } while (n != 0);
401 }
402 output = (float*restrict) ((uintptr_t) output - output_decrement);
403 input += 1;
404 }
405 }
406 }
407