// Auto-generated file. Do not edit!
// Template: src/f16-spmm/neonfp16arith.c.in
// Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
9
10 #include <assert.h>
11
12 #include <arm_neon.h>
13
14 #include <xnnpack/spmm.h>
15
16
// F16 sparse-weights x dense-input micro-kernel (NEON FP16 arithmetic).
//
// Processes the m input elements in tiles of 32 rows, then handles the
// remainder with 16/8/4/2/1-row tails. For every tile it makes a full pass over
// the n output channels. Each output is computed as
//   clamp(acc * scale, min, max)
// where acc starts from the channel's first weight-stream value and
// accumulates input * weight over the channel's non-zeros.
//
// NOTE(review): this file is generated from src/f16-spmm/neonfp16arith.c.in;
// the same fixes must be applied to the template or they will be overwritten.
//
// Parameters:
//   m           - number of rows processed per output channel; must be non-zero.
//   n           - number of output channels; must be non-zero (the channel
//                 loops are do/while).
//   input       - F16 input; each tile reads consecutive elements from `a`.
//   weights     - F16 stream. Per channel it holds one initial accumulator value
//                 (presumably the bias - TODO confirm) followed by `nnz` weights.
//   widx_dmap   - byte offsets added to the input pointer after each non-zero.
//                 The stream is restarted for every tile; the code assumes the
//                 offsets bring `a` back to the tile start after the full pass
//                 (TODO confirm against the packing code).
//   nidx_nnzmap - number of non-zero weights for each of the n channels.
//   output      - F16 output; row r of channel j is stored at output[j * m + r].
//   params      - scale, max and min applied to every output.
void xnn_f16_spmm_ukernel_32x1__neonfp16arith(
    uint32_t m,
    uint32_t n,
    const void*restrict input,
    const void*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    void*restrict output,
    const struct xnn_f16_output_params params[restrict static 1])
{
  assert(m != 0);
  // Every channel loop below is do/while, so n == 0 would underflow `j`.
  assert(n != 0);

  const __fp16*restrict a = input;
  __fp16*restrict c = output;

  // Distance `c` advances during one pass over all n channels (c += m per
  // channel). Computed in size_t so that m * n cannot wrap in 32 bits.
  const size_t output_pass = (size_t) m * (size_t) n;

  const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
  const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
  const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
  // 64-bit halves used by the 4/2/1-row tails (loop-invariant).
  const float16x4_t vscale_lo = vget_low_f16(vscale);
  const float16x4_t vmax_lo = vget_low_f16(vmax);
  const float16x4_t vmin_lo = vget_low_f16(vmin);

  size_t i = m;
  while XNN_LIKELY(i >= 32) {
    const __fp16*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    do {
      uint32_t nnz = *nnzmap++;
      float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
      float16x8_t vacc89ABCDEF = vacc01234567;
      float16x8_t vaccGHIJKLMN = vacc01234567;
      float16x8_t vaccOPQRSTUV = vacc01234567;
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float16x8_t va01234567 = vld1q_f16(a);
          const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
          const float16x8_t vaGHIJKLMN = vld1q_f16(a + 16);
          const float16x8_t vaOPQRSTUV = vld1q_f16(a + 24);
          // diff is a byte offset, hence the uintptr_t arithmetic.
          a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float16x8_t vb = vld1q_dup_f16(w); w += 1;
          vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
          vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
          vaccGHIJKLMN = vfmaq_f16(vaccGHIJKLMN, vaGHIJKLMN, vb);
          vaccOPQRSTUV = vfmaq_f16(vaccOPQRSTUV, vaOPQRSTUV, vb);
        } while (--nnz != 0);
      }
      float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
      float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
      float16x8_t voutGHIJKLMN = vmulq_f16(vaccGHIJKLMN, vscale);
      float16x8_t voutOPQRSTUV = vmulq_f16(vaccOPQRSTUV, vscale);
      vout01234567 = vminq_f16(vout01234567, vmax);
      vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
      voutGHIJKLMN = vminq_f16(voutGHIJKLMN, vmax);
      voutOPQRSTUV = vminq_f16(voutOPQRSTUV, vmax);
      vout01234567 = vmaxq_f16(vout01234567, vmin);
      vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
      voutGHIJKLMN = vmaxq_f16(voutGHIJKLMN, vmin);
      voutOPQRSTUV = vmaxq_f16(voutOPQRSTUV, vmin);
      vst1q_f16(c, vout01234567);
      vst1q_f16(c + 8, vout89ABCDEF);
      vst1q_f16(c + 16, voutGHIJKLMN);
      vst1q_f16(c + 24, voutOPQRSTUV);
      c += m;
    } while (--j != 0);
    c -= output_pass;
    c += 32;
    a += 32;
    i -= 32;
  }
  if XNN_UNLIKELY(i != 0) {
    // Remainder rows. Every tail applies the same scale + clamp as the main
    // loop so that results do not depend on where a row falls relative to the
    // 32-row tiling.
    if (i & 16) {
      const __fp16*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
        float16x8_t vacc89ABCDEF = vacc01234567;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float16x8_t va01234567 = vld1q_f16(a);
            const float16x8_t va89ABCDEF = vld1q_f16(a + 8);
            a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float16x8_t vb = vld1q_dup_f16(w); w += 1;
            vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
            vacc89ABCDEF = vfmaq_f16(vacc89ABCDEF, va89ABCDEF, vb);
          } while (--nnz != 0);
        }
        float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
        float16x8_t vout89ABCDEF = vmulq_f16(vacc89ABCDEF, vscale);
        vout01234567 = vminq_f16(vout01234567, vmax);
        vout89ABCDEF = vminq_f16(vout89ABCDEF, vmax);
        vout01234567 = vmaxq_f16(vout01234567, vmin);
        vout89ABCDEF = vmaxq_f16(vout89ABCDEF, vmin);
        vst1q_f16(c, vout01234567);
        vst1q_f16(c + 8, vout89ABCDEF);
        c += m;
      } while (--j != 0);
      c -= output_pass;
      c += 16;
      a += 16;
    }
    if (i & 8) {
      const __fp16*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float16x8_t vacc01234567 = vld1q_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float16x8_t va01234567 = vld1q_f16(a);
            a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float16x8_t vb = vld1q_dup_f16(w); w += 1;
            vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
          } while (--nnz != 0);
        }
        float16x8_t vout01234567 = vmulq_f16(vacc01234567, vscale);
        vout01234567 = vminq_f16(vout01234567, vmax);
        vout01234567 = vmaxq_f16(vout01234567, vmin);
        vst1q_f16(c, vout01234567);
        c += m;
      } while (--j != 0);
      c -= output_pass;
      c += 8;
      a += 8;
    }
    if (i & 4) {
      const __fp16*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float16x4_t va0123 = vld1_f16(a);
            a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float16x4_t vb = vld1_dup_f16(w); w += 1;
            vacc0123 = vfma_f16(vacc0123, va0123, vb);
          } while (--nnz != 0);
        }
        float16x4_t vout0123 = vmul_f16(vacc0123, vscale_lo);
        vout0123 = vmin_f16(vout0123, vmax_lo);
        vout0123 = vmax_f16(vout0123, vmin_lo);
        vst1_f16(c, vout0123);
        c += m;
      } while (--j != 0);
      c -= output_pass;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const __fp16*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            // Load two adjacent F16 values as one 32-bit lane (duplicated), then
            // view the float32x2_t as float16x4_t: lanes are [a0, a1, a0, a1].
            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32(__builtin_assume_aligned(a, 1)));
            a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float16x4_t vb = vld1_dup_f16(w); w += 1;
            vacc01 = vfma_f16(vacc01, va01, vb);
          } while (--nnz != 0);
        }
        float16x4_t vout01 = vmul_f16(vacc01, vscale_lo);
        vout01 = vmin_f16(vout01, vmax_lo);
        vout01 = vmax_f16(vout01, vmin_lo);
        // Store the two low F16 lanes as a single 32-bit lane.
        vst1_lane_f32(__builtin_assume_aligned(c, 1), vreinterpret_f32_f16(vout01), 0);
        c += m;
      } while (--j != 0);
      c -= output_pass;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const __fp16*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float16x4_t va0 = vld1_dup_f16(a);
            a = (const __fp16*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float16x4_t vb = vld1_dup_f16(w); w += 1;
            vacc0 = vfma_f16(vacc0, va0, vb);
          } while (--nnz != 0);
        }
        float16x4_t vout0 = vmul_f16(vacc0, vscale_lo);
        vout0 = vmin_f16(vout0, vmax_lo);
        vout0 = vmax_f16(vout0, vmin_lo);
        vst1_lane_f16(c, vout0, 0);
        c += m;
      } while (--j != 0);
      c -= output_pass;
      c += 1;
      a += 1;
    }
  }
}
225