// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

xnn_f32_spmm_ukernel_12x2__neonfma(uint32_t m,uint32_t n,const float * restrict a,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict c,const union xnn_f32_output_params params[restrict static1])17 void xnn_f32_spmm_ukernel_12x2__neonfma(
18     uint32_t m,
19     uint32_t n,
20     const float*restrict a,
21     const float*restrict weights,
22     const int32_t*restrict widx_dmap,
23     const uint32_t*restrict nidx_nnzmap,
24     float*restrict c,
25     const union xnn_f32_output_params params[restrict static 1])
26 {
27   assert(m != 0);
28 
29   const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
30   const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
31   size_t i = m;
32   while XNN_LIKELY(i >= 12) {
33     const float*restrict w = weights;
34     const int32_t* dmap = widx_dmap;
35     const uint32_t* nnzmap = nidx_nnzmap;
36     size_t j = n;
37     while (j >= 2) {
38       uint32_t nnz = *nnzmap++;
39       float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
40       float32x4_t vacc4567c0 = vacc0123c0;
41       float32x4_t vacc89ABc0 = vacc0123c0;
42       float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
43       float32x4_t vacc4567c1 = vacc0123c1;
44       float32x4_t vacc89ABc1 = vacc0123c1;
45       if XNN_LIKELY(nnz != 0) {
46         do {
47           const intptr_t diff = *dmap++;
48           const float32x4_t va0123 = vld1q_f32(a);
49           const float32x4_t va4567 = vld1q_f32(a + 4);
50           const float32x4_t va89AB = vld1q_f32(a + 8);
51           a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
52           const float32x2_t vb = vld1_f32(w); w += 2;
53 
54           vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
55           vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
56           vacc89ABc0 = vfmaq_lane_f32(vacc89ABc0, va89AB, vb, 0);
57           vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
58           vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
59           vacc89ABc1 = vfmaq_lane_f32(vacc89ABc1, va89AB, vb, 1);
60         } while (--nnz != 0);
61       }
62       float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
63       float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
64       float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
65       float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
66       float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
67       float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);
68 
69       vout0123c0 = vmaxq_f32(vout0123c0, vmin);
70       vout4567c0 = vmaxq_f32(vout4567c0, vmin);
71       vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
72       vout0123c1 = vmaxq_f32(vout0123c1, vmin);
73       vout4567c1 = vmaxq_f32(vout4567c1, vmin);
74       vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);
75 
76       vst1q_f32(c + 0 * m + 0, vout0123c0);
77       vst1q_f32(c + 0 * m + 4, vout4567c0);
78       vst1q_f32(c + 0 * m + 8, vout89ABc0);
79       vst1q_f32(c + 1 * m + 0, vout0123c1);
80       vst1q_f32(c + 1 * m + 4, vout4567c1);
81       vst1q_f32(c + 1 * m + 8, vout89ABc1);
82       c += 2 * m;
83       j -= 2;
84     }
85 
86     // clean up loop, fall back to nr=1
87     if XNN_UNLIKELY(j != 0) {
88       do {
89         uint32_t nnz = *nnzmap++;
90         float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
91         float32x4_t vacc4567 = vacc0123;
92         float32x4_t vacc89AB = vacc0123;
93         if XNN_LIKELY(nnz != 0) {
94           do {
95             const intptr_t diff = *dmap++;
96             const float32x4_t va0123 = vld1q_f32(a);
97             const float32x4_t va4567 = vld1q_f32(a + 4);
98             const float32x4_t va89AB = vld1q_f32(a + 8);
99             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
100             const float32x4_t vb = vld1q_dup_f32(w); w += 1;
101             vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
102             vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
103             vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
104           } while (--nnz != 0);
105         }
106         float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
107         float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
108         float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
109 
110         vout0123 = vmaxq_f32(vout0123, vmin);
111         vout4567 = vmaxq_f32(vout4567, vmin);
112         vout89AB = vmaxq_f32(vout89AB, vmin);
113 
114         vst1q_f32(c + 0, vout0123);
115         vst1q_f32(c + 4, vout4567);
116         vst1q_f32(c + 8, vout89AB);
117         c += m;
118         j -= 1;
119       } while (j != 0);
120     }
121     c -= m * n;
122     c += 12;
123     a += 12;
124     i -= 12;
125   }
126   if XNN_UNLIKELY(i != 0) {
127     if (i & 8) {
128       const float*restrict w = weights;
129       const int32_t* dmap = widx_dmap;
130       const uint32_t* nnzmap = nidx_nnzmap;
131       size_t j = n;
132       while (j >= 2) {
133         uint32_t nnz = *nnzmap++;
134         float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
135         float32x4_t vacc4567c0 = vacc0123c0;
136         float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
137         float32x4_t vacc4567c1 = vacc0123c1;
138         if XNN_LIKELY(nnz != 0) {
139           do {
140             const intptr_t diff = *dmap++;
141             const float32x4_t va0123 = vld1q_f32(a);
142             const float32x4_t va4567 = vld1q_f32(a + 4);
143             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
144             const float32x2_t vb = vld1_f32(w); w += 2;
145 
146             vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
147             vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
148             vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
149             vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
150           } while (--nnz != 0);
151         }
152         float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
153         float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
154         float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
155         float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
156 
157         vout0123c0 = vmaxq_f32(vout0123c0, vmin);
158         vout4567c0 = vmaxq_f32(vout4567c0, vmin);
159         vout0123c1 = vmaxq_f32(vout0123c1, vmin);
160         vout4567c1 = vmaxq_f32(vout4567c1, vmin);
161 
162         vst1q_f32(c + 0 * m + 0, vout0123c0);
163         vst1q_f32(c + 0 * m + 4, vout4567c0);
164         vst1q_f32(c + 1 * m + 0, vout0123c1);
165         vst1q_f32(c + 1 * m + 4, vout4567c1);
166         c += 2 * m;
167         j -= 2;
168       }
169 
170       // clean up loop, fall back to nr=1
171       if XNN_UNLIKELY(j != 0) {
172         do {
173           uint32_t nnz = *nnzmap++;
174           float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
175           float32x4_t vacc4567 = vacc0123;
176           if XNN_LIKELY(nnz != 0) {
177             do {
178               const intptr_t diff = *dmap++;
179               const float32x4_t va0123 = vld1q_f32(a);
180               const float32x4_t va4567 = vld1q_f32(a + 4);
181               a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
182               const float32x4_t vb = vld1q_dup_f32(w); w += 1;
183               vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
184               vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
185             } while (--nnz != 0);
186           }
187           float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
188           float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
189 
190           vout0123 = vmaxq_f32(vout0123, vmin);
191           vout4567 = vmaxq_f32(vout4567, vmin);
192 
193           vst1q_f32(c + 0, vout0123);
194           vst1q_f32(c + 4, vout4567);
195           c += m;
196           j -= 1;
197         } while (j != 0);
198       }
199       c -= m * n;
200       c += 8;
201       a += 8;
202     }
203     if (i & 4) {
204       const float*restrict w = weights;
205       const int32_t* dmap = widx_dmap;
206       const uint32_t* nnzmap = nidx_nnzmap;
207       size_t j = n;
208       while (j >= 2) {
209         uint32_t nnz = *nnzmap++;
210         float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
211         float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
212         if XNN_LIKELY(nnz != 0) {
213           do {
214             const intptr_t diff = *dmap++;
215             const float32x4_t va0123 = vld1q_f32(a);
216             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
217             const float32x2_t vb = vld1_f32(w); w += 2;
218 
219             vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
220             vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
221           } while (--nnz != 0);
222         }
223         float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
224         float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
225 
226         vout0123c0 = vmaxq_f32(vout0123c0, vmin);
227         vout0123c1 = vmaxq_f32(vout0123c1, vmin);
228 
229         vst1q_f32(c + 0 * m + 0, vout0123c0);
230         vst1q_f32(c + 1 * m + 0, vout0123c1);
231         c += 2 * m;
232         j -= 2;
233       }
234 
235       // clean up loop, fall back to nr=1
236       if XNN_UNLIKELY(j != 0) {
237         do {
238           uint32_t nnz = *nnzmap++;
239           float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
240           if XNN_LIKELY(nnz != 0) {
241             do {
242               const intptr_t diff = *dmap++;
243               const float32x4_t va0123 = vld1q_f32(a);
244               a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
245               const float32x4_t vb = vld1q_dup_f32(w); w += 1;
246               vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
247             } while (--nnz != 0);
248           }
249           float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
250 
251           vout0123 = vmaxq_f32(vout0123, vmin);
252 
253           vst1q_f32(c + 0, vout0123);
254           c += m;
255           j -= 1;
256         } while (j != 0);
257       }
258       c -= m * n;
259       c += 4;
260       a += 4;
261     }
262     if (i & 2) {
263       const float*restrict w = weights;
264       const int32_t* dmap = widx_dmap;
265       const uint32_t* nnzmap = nidx_nnzmap;
266       size_t j = n;
267       while (j >= 2) {
268         uint32_t nnz = *nnzmap++;
269         float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
270         float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
271         if XNN_LIKELY(nnz != 0) {
272           do {
273             const intptr_t diff = *dmap++;
274             const float32x2_t va01 = vld1_f32(a);
275             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
276             const float32x2_t vb = vld1_f32(w); w += 2;
277 
278             vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
279             vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
280           } while (--nnz != 0);
281         }
282         float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
283         float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
284 
285         vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
286         vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
287 
288         vst1_f32(c + 0 * m + 0, vout01c0);
289         vst1_f32(c + 1 * m + 0, vout01c1);
290         c += 2 * m;
291         j -= 2;
292       }
293 
294       // clean up loop, fall back to nr=1
295       if XNN_UNLIKELY(j != 0) {
296         do {
297           uint32_t nnz = *nnzmap++;
298           float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
299           if XNN_LIKELY(nnz != 0) {
300             do {
301               const intptr_t diff = *dmap++;
302               const float32x2_t va01 = vld1_f32(a);
303               a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
304               const float32x2_t vb = vld1_dup_f32(w); w += 1;
305               vacc01 = vfma_f32(vacc01, va01, vb);
306             } while (--nnz != 0);
307           }
308           float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
309           vout01 = vmax_f32(vout01, vget_low_f32(vmin));
310 
311           vst1_f32(c, vout01);
312           c += m;
313           j -= 1;
314         } while (j != 0);
315       }
316       c -= m * n;
317       c += 2;
318       a += 2;
319     }
320     if (i & 1) {
321       const float*restrict w = weights;
322       const int32_t* dmap = widx_dmap;
323       const uint32_t* nnzmap = nidx_nnzmap;
324       size_t j = n;
325       while (j >= 2) {
326         uint32_t nnz = *nnzmap++;
327         float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
328         float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
329         if XNN_LIKELY(nnz != 0) {
330           do {
331             const intptr_t diff = *dmap++;
332             const float32x2_t va0 = vld1_dup_f32(a);
333             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
334             const float32x2_t vb = vld1_f32(w); w += 2;
335 
336             vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
337             vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
338           } while (--nnz != 0);
339         }
340         float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
341         float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
342 
343         vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
344         vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
345 
346         vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
347         vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
348         c += 2 * m;
349         j -= 2;
350       }
351 
352       // clean up loop, fall back to nr=1
353       if XNN_UNLIKELY(j != 0) {
354         do {
355           uint32_t nnz = *nnzmap++;
356           float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
357           if XNN_LIKELY(nnz != 0) {
358             do {
359               const intptr_t diff = *dmap++;
360               const float32x2_t va0 = vld1_dup_f32(a);
361               a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
362               const float32x2_t vb = vld1_dup_f32(w); w += 1;
363               vacc0 = vfma_f32(vacc0, va0, vb);
364             } while (--nnz != 0);
365           }
366           float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
367           vout0 = vmax_f32(vout0, vget_low_f32(vmin));
368 
369           vst1_lane_f32(c, vout0, 1);
370           c += m;
371           j -= 1;
372         } while (j != 0);
373       }
374       c -= m * n;
375       c += 1;
376       a += 1;
377     }
378     }
379 }
380