1 // Auto-generated file. Do not edit!
2 // Template: src/f32-spmm/neon.c.in
3 // Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9
10 #include <assert.h>
11
12 #include <arm_neon.h>
13
14 #include <xnnpack/spmm.h>
15
16
xnn_f32_spmm_minmax_ukernel_8x1__neonfma_x2(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_8x1__neonfma_x2(
18 size_t mc,
19 size_t nc,
20 const float*restrict input,
21 const float*restrict weights,
22 const int32_t*restrict widx_dmap,
23 const uint32_t*restrict nidx_nnzmap,
24 float*restrict output,
25 size_t output_stride,
26 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28 assert(mc != 0);
29 assert(mc % sizeof(float) == 0);
30 assert(nc != 0);
31
32 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
33 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
34 size_t output_decrement = output_stride * nc - 8 * sizeof(float);
35 while XNN_LIKELY(mc >= 8 * sizeof(float)) {
36 const float*restrict w = weights;
37 const int32_t* dmap = widx_dmap;
38 const uint32_t* nnzmap = nidx_nnzmap;
39 size_t n = nc;
40 do {
41 uint32_t nnz = *nnzmap++;
42 float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
43 float32x4_t vacc0123x1 = vmovq_n_f32(0.0f);
44 float32x4_t vacc4567x0 = vacc0123x0;
45 float32x4_t vacc4567x1 = vmovq_n_f32(0.0f);
46 for (; nnz >= 2; nnz -= 2) {
47 const intptr_t diff0 = dmap[0];
48 const intptr_t diff1 = dmap[1];
49 dmap += 2;
50 const float32x4_t vi0123x0 = vld1q_f32(input);
51 const float32x4_t vi4567x0 = vld1q_f32(input + 4);
52 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff0);
53 __builtin_prefetch(input + 16);
54 const float32x4_t vw0 = vld1q_dup_f32(w); w += 1;
55 __builtin_prefetch(w + 32);
56 vacc0123x0 = vfmaq_f32(vacc0123x0, vi0123x0, vw0);
57 vacc4567x0 = vfmaq_f32(vacc4567x0, vi4567x0, vw0);
58 const float32x4_t vi0123x1 = vld1q_f32(input);
59 const float32x4_t vi4567x1 = vld1q_f32(input + 4);
60 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff1);
61 __builtin_prefetch(input + 16);
62 const float32x4_t vw1 = vld1q_dup_f32(w); w += 1;
63 __builtin_prefetch(w + 32);
64 vacc0123x1 = vfmaq_f32(vacc0123x1, vi0123x1, vw1);
65 vacc4567x1 = vfmaq_f32(vacc4567x1, vi4567x1, vw1);
66 }
67 float32x4_t vacc0123 = vacc0123x0;
68 float32x4_t vacc4567 = vacc4567x0;
69 vacc0123 = vaddq_f32(vacc0123, vacc0123x1);
70 vacc4567 = vaddq_f32(vacc4567, vacc4567x1);
71 if XNN_LIKELY(nnz != 0) {
72 do {
73 const intptr_t diff = *dmap++;
74 const float32x4_t vi0123 = vld1q_f32(input);
75 const float32x4_t vi4567 = vld1q_f32(input + 4);
76 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
77 __builtin_prefetch(input + 16);
78 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
79 __builtin_prefetch(w + 32);
80 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
81 vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
82 } while (--nnz != 0);
83 }
84 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
85 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
86 vout0123 = vmaxq_f32(vout0123, vmin);
87 vout4567 = vmaxq_f32(vout4567, vmin);
88 vst1q_f32(output, vout0123);
89 vst1q_f32(output + 4, vout4567);
90 output = (float*restrict) ((uintptr_t) output + output_stride);
91 } while (--n != 0);
92 output = (float*restrict) ((uintptr_t) output - output_decrement);
93 input += 8;
94 mc -= 8 * sizeof(float);
95 }
96 if XNN_UNLIKELY(mc != 0) {
97 output_decrement += 4 * sizeof(float);
98 if (mc & (4 * sizeof(float))) {
99 const float*restrict w = weights;
100 const int32_t* dmap = widx_dmap;
101 const uint32_t* nnzmap = nidx_nnzmap;
102 size_t n = nc;
103 do {
104 uint32_t nnz = *nnzmap++;
105 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
106 if XNN_LIKELY(nnz != 0) {
107 do {
108 const intptr_t diff = *dmap++;
109 const float32x4_t vi0123 = vld1q_f32(input);
110 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
111 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
112 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
113 } while (--nnz != 0);
114 }
115 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
116 vout0123 = vmaxq_f32(vout0123, vmin);
117 vst1q_f32(output, vout0123);
118 output = (float*restrict) ((uintptr_t) output + output_stride);
119 } while (--n != 0);
120 output = (float*restrict) ((uintptr_t) output - output_decrement);
121 input += 4;
122 }
123 output_decrement += 2 * sizeof(float);
124 if (mc & (2 * sizeof(float))) {
125 const float*restrict w = weights;
126 const int32_t* dmap = widx_dmap;
127 const uint32_t* nnzmap = nidx_nnzmap;
128 size_t n = nc;
129 do {
130 uint32_t nnz = *nnzmap++;
131 float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
132 if XNN_LIKELY(nnz != 0) {
133 do {
134 const intptr_t diff = *dmap++;
135 const float32x2_t vi01 = vld1_f32(input);
136 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
137 const float32x2_t vw = vld1_dup_f32(w); w += 1;
138 vacc01 = vfma_f32(vacc01, vi01, vw);
139 } while (--nnz != 0);
140 }
141 float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
142 vout01 = vmax_f32(vout01, vget_low_f32(vmin));
143 vst1_f32(output, vout01);
144 output = (float*restrict) ((uintptr_t) output + output_stride);
145 } while (--n != 0);
146 output = (float*restrict) ((uintptr_t) output - output_decrement);
147 input += 2;
148 }
149 output_decrement += 1 * sizeof(float);
150 if (mc & (1 * sizeof(float))) {
151 const float*restrict w = weights;
152 const int32_t* dmap = widx_dmap;
153 const uint32_t* nnzmap = nidx_nnzmap;
154 size_t n = nc;
155 do {
156 uint32_t nnz = *nnzmap++;
157 float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
158 if XNN_LIKELY(nnz != 0) {
159 do {
160 const intptr_t diff = *dmap++;
161 const float32x2_t vi0 = vld1_dup_f32(input);
162 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
163 const float32x2_t vw = vld1_dup_f32(w); w += 1;
164 vacc0 = vfma_f32(vacc0, vi0, vw);
165 } while (--nnz != 0);
166 }
167 float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
168 vout0 = vmax_f32(vout0, vget_low_f32(vmin));
169 vst1_lane_f32(output, vout0, 0);
170 output = (float*restrict) ((uintptr_t) output + output_stride);
171 } while (--n != 0);
172 output = (float*restrict) ((uintptr_t) output - output_decrement);
173 input += 1;
174 }
175 }
176 }
177