// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

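// Sparse matrix-dense matrix multiplication (SpMM) micro-kernel with a
// 16x2 tile: 16 batch elements ("m" direction) by 2 output channels
// ("n" direction) per iteration. The semantics below are inferred from
// the pointer arithmetic in the kernel body:
//
//   m           - number of batch elements (floats per output channel)
//   n           - number of output channels
//   a           - dense input; m consecutive floats per input channel
//   weights     - per output-channel block: the bias value(s), then the
//                 non-zero weight values interleaved across the block
//   widx_dmap   - byte offset applied to the input pointer after each
//                 non-zero weight; the offsets of a block net to zero so
//                 the pointer returns to the start of the current tile
//   nidx_nnzmap - number of non-zero weights per output-channel block
//   c           - output; m consecutive floats per output channel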
void xnn_f32_spmm_ukernel_16x2__neonfma(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
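  // Main loop: tiles of 16 elements along the m (batch) dimension.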
  while XNN_LIKELY(i >= 16) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
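    // Blocked loop over the output channels in pairs (nr = 2); both
    // channels of a pair reuse the same input loads.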
    while (j >= 2) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c0 = vacc0123c0;
      float32x4_t vacc89ABc0 = vacc0123c0;
      float32x4_t vaccCDEFc0 = vacc0123c0;
      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c1 = vacc0123c1;
      float32x4_t vacc89ABc1 = vacc0123c1;
      float32x4_t vaccCDEFc1 = vacc0123c1;
      if XNN_LIKELY(nnz != 0) {
        do {
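          // One non-zero weight pair: load 16 inputs, rebase the input
          // pointer by the dmap byte offset, and accumulate into both
          // output channels.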
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          const float32x4_t va4567 = vld1q_f32(a + 4);
          const float32x4_t va89AB = vld1q_f32(a + 8);
          const float32x4_t vaCDEF = vld1q_f32(a + 12);
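          // Prefetch the data just past the 16 floats loaded above.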
          __builtin_prefetch(a + 16);
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x2_t vb = vld1_f32(w); w += 2;

          vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
          vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
          vacc89ABc0 = vfmaq_lane_f32(vacc89ABc0, va89AB, vb, 0);
          vaccCDEFc0 = vfmaq_lane_f32(vaccCDEFc0, vaCDEF, vb, 0);
          vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
          vacc89ABc1 = vfmaq_lane_f32(vacc89ABc1, va89AB, vb, 1);
          vaccCDEFc1 = vfmaq_lane_f32(vaccCDEFc1, vaCDEF, vb, 1);
        } while (--nnz != 0);
      }
      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
      float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
      float32x4_t voutCDEFc0 = vminq_f32(vaccCDEFc0, vmax);
      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
      float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);
      float32x4_t voutCDEFc1 = vminq_f32(vaccCDEFc1, vmax);

      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
      vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
      voutCDEFc0 = vmaxq_f32(voutCDEFc0, vmin);
      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
      vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);
      voutCDEFc1 = vmaxq_f32(voutCDEFc1, vmin);

      vst1q_f32(c + 0 * m + 0, vout0123c0);
      vst1q_f32(c + 0 * m + 4, vout4567c0);
      vst1q_f32(c + 0 * m + 8, vout89ABc0);
      vst1q_f32(c + 0 * m + 12, voutCDEFc0);
      vst1q_f32(c + 1 * m + 0, vout0123c1);
      vst1q_f32(c + 1 * m + 4, vout4567c1);
      vst1q_f32(c + 1 * m + 8, vout89ABc1);
      vst1q_f32(c + 1 * m + 12, voutCDEFc1);
      c += 2 * m;
      j -= 2;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(j != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        float32x4_t vaccCDEF = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            const float32x4_t va89AB = vld1q_f32(a + 8);
            const float32x4_t vaCDEF = vld1q_f32(a + 12);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
            vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
            vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
        float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);
        voutCDEF = vmaxq_f32(voutCDEF, vmin);

        vst1q_f32(c + 0, vout0123);
        vst1q_f32(c + 4, vout4567);
        vst1q_f32(c + 8, vout89AB);
        vst1q_f32(c + 12, voutCDEF);
        c += m;
        j -= 1;
      } while (j != 0);
    }
    c -= m * n;
    c += 16;
    a += 16;
    i -= 16;
  }
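  // Handle the m % 16 leftover elements with progressively narrower
  // 8-, 4-, 2- and 1-element tiles.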
  if XNN_UNLIKELY(i != 0) {
    if (i & 8) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567c0 = vacc0123c0;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567c1 = vacc0123c1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
            vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
        float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout4567c0 = vmaxq_f32(vout4567c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
        vout4567c1 = vmaxq_f32(vout4567c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 0 * m + 4, vout4567c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        vst1q_f32(c + 1 * m + 4, vout4567c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          float32x4_t vacc4567 = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              const float32x4_t va4567 = vld1q_f32(a + 4);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
              vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);
          vout4567 = vmaxq_f32(vout4567, vmin);

          vst1q_f32(c + 0, vout0123);
          vst1q_f32(c + 4, vout4567);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 8;
      a += 8;
    }
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(c + 0, vout0123);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va01 = vld1_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
            vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));

        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));

        vst1_f32(c + 0 * m + 0, vout01c0);
        vst1_f32(c + 1 * m + 0, vout01c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va01 = vld1_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, va01, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(c, vout01);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va0 = vld1_dup_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
            vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));

        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));

        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va0 = vld1_dup_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, va0, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          vst1_lane_f32(c, vout0, 0);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
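
// ---------------------------------------------------------------------------
// Illustrative driver (editorial addition; NOT part of the generated file).
// It builds the smallest possible problem for this micro-kernel -- one
// 16-element batch tile and one output-channel pair with two non-zero
// weights -- following the data layout inferred in the header comment above.
// The function name, the test values, and the clamp bounds are hypothetical.
// ---------------------------------------------------------------------------
#include <math.h>  // INFINITY

void xnn_f32_spmm_16x2_example(float c[32]) {
  // Dense input, laid out [input_channel][batch]:
  // 2 input channels x 16 batch elements.
  float a[2 * 16];
  for (int k = 0; k < 2 * 16; k++) {
    a[k] = (float) k;
  }
  // Weights for one output-channel pair: bias_c0, bias_c1, then one
  // (w_c0, w_c1) pair per non-zero input channel.
  static const float weights[6] = { 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f };
  // Byte offsets applied to the input pointer after each non-zero weight:
  // +64 bytes (16 floats) forward to input channel 1, then -64 back to the
  // start of the tile for the next output-channel block.
  static const int32_t widx_dmap[2] = { +64, -64 };
  static const uint32_t nidx_nnzmap[1] = { 2 };  // non-zeros in this block
  // No clamping: min/max set to -/+infinity.
  const union xnn_f32_output_params params = {
    .scalar = { .min = -INFINITY, .max = INFINITY },
  };
  xnn_f32_spmm_ukernel_16x2__neonfma(
      /*m=*/16, /*n=*/2, a, weights, widx_dmap, nidx_nnzmap, c, &params);
  // Expected: c[0..15]  = 1*a[0..15] + 3*a[16..31],
  //           c[16..31] = 2*a[0..15] + 4*a[16..31].
}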