// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


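// Sparse-matrix times dense-matrix (SpMM) microkernel. The layout notes
// below are inferred from how the arguments are used in this kernel:
//   m           - rows of output (and of the dense input) handled per call
//   n           - number of output channels
//   a           - dense input; advanced by the byte offsets in widx_dmap
//   weights     - packed non-zero weights; each channel block starts with the
//                 initial accumulator values (presumably the biases)
//   widx_dmap   - per non-zero entry, byte offset to the next input element
//   nidx_nnzmap - per output-channel block, the count of non-zero entries
//   c           - output, channel-major: c[channel * m + row]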
void xnn_f32_spmm_ukernel_8x4__neonfma(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
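  // Main loop: process the M dimension in blocks of 8 rows.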
  while XNN_LIKELY(i >= 8) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
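    // Blocked loop: process 4 output channels at a time.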
    while (j >= 4) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c0 = vacc0123c0;
      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c1 = vacc0123c1;
      float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c2 = vacc0123c2;
      float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c3 = vacc0123c3;
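      // Accumulate over this block's non-zero weights; each dmap entry is the
      // byte offset from the current input element to the next one.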
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          const float32x4_t va4567 = vld1q_f32(a + 4);
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x4_t vb = vld1q_f32(w); w += 4;

          vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
          vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
          vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
          vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
          vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
          vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
          vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
          vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
        } while (--nnz != 0);
      }
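      // Clamp the accumulators to the [min, max] output range.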
      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
      float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
      float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
      float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
      float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);

      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
      vout0123c2 = vmaxq_f32(vout0123c2, vmin);
      vout4567c2 = vmaxq_f32(vout4567c2, vmin);
      vout0123c3 = vmaxq_f32(vout0123c3, vmin);
      vout4567c3 = vmaxq_f32(vout4567c3, vmin);

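      // Store the 8-row x 4-channel output tile, channel-major.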
      vst1q_f32(c + 0 * m + 0, vout0123c0);
      vst1q_f32(c + 0 * m + 4, vout4567c0);
      vst1q_f32(c + 1 * m + 0, vout0123c1);
      vst1q_f32(c + 1 * m + 4, vout4567c1);
      vst1q_f32(c + 2 * m + 0, vout0123c2);
      vst1q_f32(c + 2 * m + 4, vout4567c2);
      vst1q_f32(c + 3 * m + 0, vout0123c3);
      vst1q_f32(c + 3 * m + 4, vout4567c3);
      c += 4 * m;
      j -= 4;
    }

    // Cleanup loop: fall back to nr=1 for the remaining output channels.
    if XNN_UNLIKELY(j != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);

        vst1q_f32(c + 0, vout0123);
        vst1q_f32(c + 4, vout4567);
        c += m;
        j -= 1;
      } while (j != 0);
    }
    c -= m * n;
    c += 8;
    a += 8;
    i -= 8;
  }
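  // Remainder: handle the last i < 8 rows in 4-, 2-, and 1-row blocks.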
  if XNN_UNLIKELY(i != 0) {
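    // Block of 4 remaining rows.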
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_f32(w); w += 4;

            vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
            vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
            vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
            vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
        float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
        float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
        vout0123c2 = vmaxq_f32(vout0123c2, vmin);
        vout0123c3 = vmaxq_f32(vout0123c3, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        vst1q_f32(c + 2 * m + 0, vout0123c2);
        vst1q_f32(c + 3 * m + 0, vout0123c3);
        c += 4 * m;
        j -= 4;
      }

      // Cleanup loop: fall back to nr=1 for the remaining output channels.
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(c + 0, vout0123);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 4;
      a += 4;
    }
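    // Block of 2 remaining rows.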
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va01 = vld1_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_f32(w); w += 4;

            vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0);
            vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1);
            vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2);
            vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
        float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax));
        float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax));

        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
        vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin));
        vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin));

        vst1_f32(c + 0 * m + 0, vout01c0);
        vst1_f32(c + 1 * m + 0, vout01c1);
        vst1_f32(c + 2 * m + 0, vout01c2);
        vst1_f32(c + 3 * m + 0, vout01c3);
        c += 4 * m;
        j -= 4;
      }

      // Cleanup loop: fall back to nr=1 for the remaining output channels.
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va01 = vld1_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, va01, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(c, vout01);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 2;
      a += 2;
    }
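    // Final remaining row.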
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va0 = vld1_dup_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_f32(w); w += 4;

            vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0);
            vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1);
            vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2);
            vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
        float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax));
        float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax));

        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
        vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin));
        vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin));

        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
        vst1_lane_f32(c + 2 * m + 0, vout0c2, 0);
        vst1_lane_f32(c + 3 * m + 0, vout0c3, 0);
        c += 4 * m;
        j -= 4;
      }

      // Cleanup loop: fall back to nr=1 for the remaining output channels.
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va0 = vld1_dup_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, va0, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          vst1_lane_f32(c, vout0, 0);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}