// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-pipelined.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


xnn_f32_spmm_minmax_ukernel_32x1__neonfma_pipelined(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_32x1__neonfma_pipelined(
18 size_t mc,
19 size_t nc,
20 const float*restrict input,
21 const float*restrict weights,
22 const int32_t*restrict widx_dmap,
23 const uint32_t*restrict nidx_nnzmap,
24 float*restrict output,
25 size_t output_stride,
26 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28 assert(mc != 0);
29 assert(mc % sizeof(float) == 0);
30 assert(nc != 0);
31
32 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
33 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
34 size_t output_decrement = output_stride * nc - 32 * sizeof(float);
35 while XNN_LIKELY(mc >= 32 * sizeof(float)) {
36 const float*restrict w = weights;
37 const int32_t* dmap = widx_dmap;
38 const uint32_t* nnzmap = nidx_nnzmap;
39 float32x4_t vw = vld1q_dup_f32(w); w += 1;
40 intptr_t diff = *dmap++;
41 float32x4_t vi0123 = vld1q_f32(input);
42 float32x4_t vi4567 = vld1q_f32(input + 4);
43 float32x4_t vi89AB = vld1q_f32(input + 8);
44 float32x4_t viCDEF = vld1q_f32(input + 12);
45 float32x4_t viGHIJ = vld1q_f32(input + 16);
46 float32x4_t viKLMN = vld1q_f32(input + 20);
47 float32x4_t viOPQR = vld1q_f32(input + 24);
48 float32x4_t viSTUV = vld1q_f32(input + 28);
49 size_t n = nc;
50 do {
51 uint32_t nnz = *nnzmap++;
52 float32x4_t vacc0123 = vw;
53 float32x4_t vacc4567 = vw;
54 float32x4_t vacc89AB = vw;
55 float32x4_t vaccCDEF = vw;
56 float32x4_t vaccGHIJ = vw;
57 float32x4_t vaccKLMN = vw;
58 float32x4_t vaccOPQR = vw;
59 float32x4_t vaccSTUV = vw;
60 vw = vld1q_dup_f32(w); w += 1;
61 if XNN_LIKELY(nnz != 0) {
62 do {
63 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
64 vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
65 vacc89AB = vfmaq_f32(vacc89AB, vi89AB, vw);
66 vaccCDEF = vfmaq_f32(vaccCDEF, viCDEF, vw);
67 vaccGHIJ = vfmaq_f32(vaccGHIJ, viGHIJ, vw);
68 vaccKLMN = vfmaq_f32(vaccKLMN, viKLMN, vw);
69 vaccOPQR = vfmaq_f32(vaccOPQR, viOPQR, vw);
70 vaccSTUV = vfmaq_f32(vaccSTUV, viSTUV, vw);
71 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
72 __builtin_prefetch(input + 16);
73 __builtin_prefetch(input + 32);
74 diff = *dmap++;
75 vw = vld1q_dup_f32(w); w += 1;
76 __builtin_prefetch(w + 32);
77 vi0123 = vld1q_f32(input);
78 vi4567 = vld1q_f32(input + 4);
79 vi89AB = vld1q_f32(input + 8);
80 viCDEF = vld1q_f32(input + 12);
81 viGHIJ = vld1q_f32(input + 16);
82 viKLMN = vld1q_f32(input + 20);
83 viOPQR = vld1q_f32(input + 24);
84 viSTUV = vld1q_f32(input + 28);
85 } while (--nnz != 0);
86 }
87 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
88 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
89 float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
90 float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
91 float32x4_t voutGHIJ = vminq_f32(vaccGHIJ, vmax);
92 float32x4_t voutKLMN = vminq_f32(vaccKLMN, vmax);
93 float32x4_t voutOPQR = vminq_f32(vaccOPQR, vmax);
94 float32x4_t voutSTUV = vminq_f32(vaccSTUV, vmax);
95 vout0123 = vmaxq_f32(vout0123, vmin);
96 vout4567 = vmaxq_f32(vout4567, vmin);
97 vout89AB = vmaxq_f32(vout89AB, vmin);
98 voutCDEF = vmaxq_f32(voutCDEF, vmin);
99 voutGHIJ = vmaxq_f32(voutGHIJ, vmin);
100 voutKLMN = vmaxq_f32(voutKLMN, vmin);
101 voutOPQR = vmaxq_f32(voutOPQR, vmin);
102 voutSTUV = vmaxq_f32(voutSTUV, vmin);
103 vst1q_f32(output, vout0123);
104 vst1q_f32(output + 4, vout4567);
105 vst1q_f32(output + 8, vout89AB);
106 vst1q_f32(output + 12, voutCDEF);
107 vst1q_f32(output + 16, voutGHIJ);
108 vst1q_f32(output + 20, voutKLMN);
109 vst1q_f32(output + 24, voutOPQR);
110 vst1q_f32(output + 28, voutSTUV);
111 output = (float*restrict) ((uintptr_t) output + output_stride);
112 } while (--n != 0);
113 output = (float*restrict) ((uintptr_t) output - output_decrement);
114 input += 32;
115 mc -= 32 * sizeof(float);
116 }
117 if XNN_UNLIKELY(mc != 0) {
118 output_decrement += 16 * sizeof(float);
119 if (mc & (16 * sizeof(float))) {
120 const float*restrict w = weights;
121 const int32_t* dmap = widx_dmap;
122 const uint32_t* nnzmap = nidx_nnzmap;
123 size_t n = nc;
124 do {
125 uint32_t nnz = *nnzmap++;
126 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
127 float32x4_t vacc4567 = vacc0123;
128 float32x4_t vacc89AB = vacc0123;
129 float32x4_t vaccCDEF = vacc0123;
130 if XNN_LIKELY(nnz != 0) {
131 do {
132 const intptr_t diff = *dmap++;
133 const float32x4_t vi0123 = vld1q_f32(input);
134 const float32x4_t vi4567 = vld1q_f32(input + 4);
135 const float32x4_t vi89AB = vld1q_f32(input + 8);
136 const float32x4_t viCDEF = vld1q_f32(input + 12);
137 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
138 __builtin_prefetch(input + 16);
139 __builtin_prefetch(input + 32);
140 const float32x4_t vb = vld1q_dup_f32(w); w += 1;
141 __builtin_prefetch(w + 32);
142 vacc0123 = vfmaq_f32(vacc0123, vi0123, vb);
143 vacc4567 = vfmaq_f32(vacc4567, vi4567, vb);
144 vacc89AB = vfmaq_f32(vacc89AB, vi89AB, vb);
145 vaccCDEF = vfmaq_f32(vaccCDEF, viCDEF, vb);
146 } while (--nnz != 0);
147 }
148 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
149 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
150 float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
151 float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
152 vout0123 = vmaxq_f32(vout0123, vmin);
153 vout4567 = vmaxq_f32(vout4567, vmin);
154 vout89AB = vmaxq_f32(vout89AB, vmin);
155 voutCDEF = vmaxq_f32(voutCDEF, vmin);
156 vst1q_f32(output, vout0123);
157 vst1q_f32(output + 4, vout4567);
158 vst1q_f32(output + 8, vout89AB);
159 vst1q_f32(output + 12, voutCDEF);
160 output = (float*restrict) ((uintptr_t) output + output_stride);
161 } while (--n != 0);
162 output = (float*restrict) ((uintptr_t) output - output_decrement);
163 input += 16;
164 }
165 output_decrement += 8 * sizeof(float);
166 if (mc & (8 * sizeof(float))) {
167 const float*restrict w = weights;
168 const int32_t* dmap = widx_dmap;
169 const uint32_t* nnzmap = nidx_nnzmap;
170 size_t n = nc;
171 do {
172 uint32_t nnz = *nnzmap++;
173 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
174 float32x4_t vacc4567 = vacc0123;
175 if XNN_LIKELY(nnz != 0) {
176 do {
177 const intptr_t diff = *dmap++;
178 const float32x4_t vi0123 = vld1q_f32(input);
179 const float32x4_t vi4567 = vld1q_f32(input + 4);
180 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
181 __builtin_prefetch(input + 16);
182 __builtin_prefetch(input + 32);
183 const float32x4_t vb = vld1q_dup_f32(w); w += 1;
184 __builtin_prefetch(w + 32);
185 vacc0123 = vfmaq_f32(vacc0123, vi0123, vb);
186 vacc4567 = vfmaq_f32(vacc4567, vi4567, vb);
187 } while (--nnz != 0);
188 }
189 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
190 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
191 vout0123 = vmaxq_f32(vout0123, vmin);
192 vout4567 = vmaxq_f32(vout4567, vmin);
193 vst1q_f32(output, vout0123);
194 vst1q_f32(output + 4, vout4567);
195 output = (float*restrict) ((uintptr_t) output + output_stride);
196 } while (--n != 0);
197 output = (float*restrict) ((uintptr_t) output - output_decrement);
198 input += 8;
199 }
200 output_decrement += 4 * sizeof(float);
201 if (mc & (4 * sizeof(float))) {
202 const float*restrict w = weights;
203 const int32_t* dmap = widx_dmap;
204 const uint32_t* nnzmap = nidx_nnzmap;
205 size_t n = nc;
206 do {
207 uint32_t nnz = *nnzmap++;
208 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
209 if XNN_LIKELY(nnz != 0) {
210 do {
211 const intptr_t diff = *dmap++;
212 const float32x4_t vi0123 = vld1q_f32(input);
213 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
214 __builtin_prefetch(input + 16);
215 __builtin_prefetch(input + 32);
216 const float32x4_t vb = vld1q_dup_f32(w); w += 1;
217 __builtin_prefetch(w + 32);
218 vacc0123 = vfmaq_f32(vacc0123, vi0123, vb);
219 } while (--nnz != 0);
220 }
221 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
222 vout0123 = vmaxq_f32(vout0123, vmin);
223 vst1q_f32(output, vout0123);
224 output = (float*restrict) ((uintptr_t) output + output_stride);
225 } while (--n != 0);
226 output = (float*restrict) ((uintptr_t) output - output_decrement);
227 input += 4;
228 }
229 output_decrement += 2 * sizeof(float);
230 if (mc & (2 * sizeof(float))) {
231 const float*restrict w = weights;
232 const int32_t* dmap = widx_dmap;
233 const uint32_t* nnzmap = nidx_nnzmap;
234 size_t n = nc;
235 do {
236 uint32_t nnz = *nnzmap++;
237 float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
238 if XNN_LIKELY(nnz != 0) {
239 do {
240 const intptr_t diff = *dmap++;
241 const float32x2_t vi01 = vld1_f32(input);
242 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
243 __builtin_prefetch(input + 16);
244 __builtin_prefetch(input + 32);
245 const float32x2_t vb = vld1_dup_f32(w); w += 1;
246 __builtin_prefetch(w + 32);
247 vacc01 = vfma_f32(vacc01, vi01, vb);
248 } while (--nnz != 0);
249 }
250 float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
251 vout01 = vmax_f32(vout01, vget_low_f32(vmin));
252 vst1_f32(output, vout01);
253 output = (float*restrict) ((uintptr_t) output + output_stride);
254 } while (--n != 0);
255 output = (float*restrict) ((uintptr_t) output - output_decrement);
256 input += 2;
257 }
258 output_decrement += 1 * sizeof(float);
259 if (mc & (1 * sizeof(float))) {
260 const float*restrict w = weights;
261 const int32_t* dmap = widx_dmap;
262 const uint32_t* nnzmap = nidx_nnzmap;
263 size_t n = nc;
264 do {
265 uint32_t nnz = *nnzmap++;
266 float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
267 if XNN_LIKELY(nnz != 0) {
268 do {
269 const intptr_t diff = *dmap++;
270 const float32x2_t vi0 = vld1_dup_f32(input);
271 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
272 __builtin_prefetch(input + 16);
273 __builtin_prefetch(input + 32);
274 const float32x2_t vb = vld1_dup_f32(w); w += 1;
275 __builtin_prefetch(w + 32);
276 vacc0 = vfma_f32(vacc0, vi0, vb);
277 } while (--nnz != 0);
278 }
279 float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
280 vout0 = vmax_f32(vout0, vget_low_f32(vmin));
281 vst1_lane_f32(output, vout0, 0);
282 output = (float*restrict) ((uintptr_t) output + output_stride);
283 } while (--n != 0);
284 output = (float*restrict) ((uintptr_t) output - output_decrement);
285 input += 1;
286 }
287 }
288 }
289