// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/conv.h>
#include <xnnpack/math.h>

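// 3x3 convolution, stride 2, padding 1 (3x3s2p1), reading 3 interleaved input
// channels (c3) and producing 4 output channels (x4) per pass over a row pair,
// with a 2x2 output tile (2 rows x 2 columns) per inner-loop iteration, using
// NEON FMA intrinsics. "hwc2spchw" means the input is HWC-interleaved and the
// output is written in split-CHW (planar) form.
//
// In outline, for each output channel c and output position (y, x):
//   o[c][y][x] = clamp(bias[c]
//       + sum over (ky, kx, ci) of k[ky][kx][ci][c] * in[2*y + ky - pad_top][2*x + kx - 1][ci],
//       min, max)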
void xnn_f32_conv_hwc2spchw_ukernel_3x3s2p1c3x4__neonfma_2x2(
    size_t input_height,
    size_t input_width,
    size_t output_y_start,
    size_t output_y_end,
    const float* input,
    const float* zero,
    const float* weights,
    float* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride,
    size_t output_channel_stride,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(input_width != 0);
  assert(output_y_end > output_y_start);
  assert(input_padding_top <= 1);
  assert(output_channels != 0);

  const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
  const size_t input_width_increment = round_down_po2(input_width, 4) * 3 /* channels */ * sizeof(float);
  const size_t output_width = (input_width + 1) / 2;
  const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float);
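  // Note: the main loop below consumes the row width in groups of 4 pixels
  // and only then advances the input pointers, so the per-row rewind
  // (input_width_increment) covers just the round_down_po2() portion.
  // output_channel_increment rewinds one row of a channel plane and jumps
  // forward 4 channel planes.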

  // Adjustment for padding processed below
  const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
  const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
  const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
  const float* i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
  const float* i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
  float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);
  float* output1 = (float*) ((uintptr_t) output0 + output_height_stride);

  if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
    i0 = zero;
  }

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);

  for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 2) {
    const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
    const size_t input_y4 = input_y2 + 2;
    if XNN_UNPREDICTABLE(input_y2 >= input_height) {
      i2 = zero;
    }
    if XNN_UNPREDICTABLE(input_y4 > input_height) {
      i3 = zero;
    }
    if XNN_UNPREDICTABLE(input_y4 >= input_height) {
      i4 = zero;
    }
    if XNN_UNPREDICTABLE(output_y + 2 > output_y_end) {
      output1 = output0;
    }

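    // Weights are packed per group of 4 output channels: 4 bias values
    // followed by 27 vectors of 4 floats (3 kernel rows x 3 kernel columns x
    // 3 input channels, one float per output channel), i.e. 112 floats per
    // group; the w + 4 ... w + 108 offsets below walk this layout.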
    const float* w = weights;
    size_t c = output_channels;
    float* o0c0 = output0;
    float* o1c0 = output1;
    float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride);
    float* o1c1 = (float*) ((uintptr_t) o1c0 + output_channel_stride);
    float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride);
    float* o1c2 = (float*) ((uintptr_t) o1c1 + output_channel_stride);
    float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride);
    float* o1c3 = (float*) ((uintptr_t) o1c2 + output_channel_stride);
    do {
      if XNN_UNPREDICTABLE(c < 2) {
        o0c1 = o0c0;
        o1c1 = o1c0;
      }
      if XNN_UNPREDICTABLE(c <= 2) {
        o0c2 = o0c1;
        o1c2 = o1c1;
      }
      if XNN_UNPREDICTABLE(c < 4) {
        o0c3 = o0c2;
        o1c3 = o1c2;
      }

      // viMx0 = ( iM0c2, iM0c1, iM0c0, --- )
      float32x4_t vi0x0 = vmovq_n_f32(0.0f);
      float32x4_t vi1x0 = vmovq_n_f32(0.0f);
      float32x4_t vi2x0 = vmovq_n_f32(0.0f);
      float32x4_t vi3x0 = vmovq_n_f32(0.0f);
      float32x4_t vi4x0 = vmovq_n_f32(0.0f);
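
      // Main loop: each iteration consumes 4 input pixels (three 4-float
      // loads of interleaved HWC data) and produces 2 output pixels for each
      // of the 2 output rows and 4 output channels.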
      size_t iw = input_width;
      for (; iw >= 4; iw -= 4) {
        float32x4_t vo0x0 = vld1q_f32(w);
        float32x4_t vo1x0 = vo0x0;
        float32x4_t vo0x1 = vo0x0;
        float32x4_t vo1x1 = vo0x0;

        const float32x4_t vk00c0 = vld1q_f32(w + 4);

        // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
        const float32x4_t vi0x1 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1x1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2x1 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3x1 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4x1 = vld1q_f32(i4); i4 += 4;

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3);

        const float32x4_t vk10c0 = vld1q_f32(w + 8);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3);

        const float32x4_t vk20c0 = vld1q_f32(w + 12);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3);

        const float32x4_t vk00c1 = vld1q_f32(w + 16);

        // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
        const float32x4_t vi0x2 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1x2 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2x2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3x2 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4x2 = vld1q_f32(i4); i4 += 4;

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0);

        const float32x4_t vk10c1 = vld1q_f32(w + 20);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0);

        const float32x4_t vk20c1 = vld1q_f32(w + 24);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0);

        const float32x4_t vk00c2 = vld1q_f32(w + 28);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1);

        const float32x4_t vk10c2 = vld1q_f32(w + 32);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1);

        const float32x4_t vk20c2 = vld1q_f32(w + 36);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1);

        const float32x4_t vk01c0 = vld1q_f32(w + 40);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2);

        const float32x4_t vk11c0 = vld1q_f32(w + 44);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2);

        const float32x4_t vk21c0 = vld1q_f32(w + 48);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2);

        const float32x4_t vk01c1 = vld1q_f32(w + 52);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3);

        const float32x4_t vk11c1 = vld1q_f32(w + 56);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3);

        const float32x4_t vk21c1 = vld1q_f32(w + 60);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3);

        const float32x4_t vk01c2 = vld1q_f32(w + 64);

        // viMx3 = ( iM4c2, iM4c1, iM4c0, iM3c2 )
        const float32x4_t vi0x3 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1x3 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2x3 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3x3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4x3 = vld1q_f32(i4); i4 += 4;

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0);

        const float32x4_t vk11c2 = vld1q_f32(w + 68);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0);

        const float32x4_t vk21c2 = vld1q_f32(w + 72);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0);

        const float32x4_t vk02c0 = vld1q_f32(w + 76);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c0, vi0x3, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c0, vi2x3, 1);

        const float32x4_t vk12c0 = vld1q_f32(w + 80);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c0, vi1x3, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c0, vi3x3, 1);

        const float32x4_t vk22c0 = vld1q_f32(w + 84);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c0, vi2x3, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c0, vi4x3, 1);

        const float32x4_t vk02c1 = vld1q_f32(w + 88);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c1, vi0x3, 2);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c1, vi2x3, 2);

        const float32x4_t vk12c1 = vld1q_f32(w + 92);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c1, vi1x3, 2);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c1, vi3x3, 2);

        const float32x4_t vk22c1 = vld1q_f32(w + 96);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c1, vi2x3, 2);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c1, vi4x3, 2);

        const float32x4_t vk02c2 = vld1q_f32(w + 100);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c2, vi0x3, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c2, vi2x3, 3);

        const float32x4_t vk12c2 = vld1q_f32(w + 104);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c2, vi1x3, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c2, vi3x3, 3);

        const float32x4_t vk22c2 = vld1q_f32(w + 108);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c2, vi2x3, 3);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c2, vi4x3, 3);

        vi0x0 = vi0x3;
        vi1x0 = vi1x3;
        vi2x0 = vi2x3;
        vi3x0 = vi3x3;
        vi4x0 = vi4x3;
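        // The last load already holds the first pixel of the next group (the
        // right-edge column of this one), so it carries over as viMx0.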

        vo0x0 = vmaxq_f32(vo0x0, vmin);
        vo1x0 = vmaxq_f32(vo1x0, vmin);
        vo0x1 = vmaxq_f32(vo0x1, vmin);
        vo1x1 = vmaxq_f32(vo1x1, vmin);

        vo0x0 = vminq_f32(vo0x0, vmax);
        vo1x0 = vminq_f32(vo1x0, vmax);
        vo0x1 = vminq_f32(vo0x1, vmax);
        vo1x1 = vminq_f32(vo1x1, vmax);

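        // Transpose pixel-major accumulators (4 channels per vector) into
        // channel pairs: each low/high half below holds 2 adjacent output
        // pixels of a single channel plane.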
        const float32x4_t vo0c01 = vzip1q_f32(vo0x0, vo0x1);
        const float32x4_t vo0c23 = vzip2q_f32(vo0x0, vo0x1);
        const float32x4_t vo1c01 = vzip1q_f32(vo1x0, vo1x1);
        const float32x4_t vo1c23 = vzip2q_f32(vo1x0, vo1x1);

        // Always 2+ output width elements remaining
        vst1_f32(o1c0, vget_low_f32(vo1c01)); o1c0 += 2;
        vst1_f32(o1c1, vget_high_f32(vo1c01)); o1c1 += 2;
        vst1_f32(o1c2, vget_low_f32(vo1c23)); o1c2 += 2;
        vst1_f32(o1c3, vget_high_f32(vo1c23)); o1c3 += 2;

        vst1_f32(o0c0, vget_low_f32(vo0c01)); o0c0 += 2;
        vst1_f32(o0c1, vget_high_f32(vo0c01)); o0c1 += 2;
        vst1_f32(o0c2, vget_low_f32(vo0c23)); o0c2 += 2;
        vst1_f32(o0c3, vget_high_f32(vo0c23)); o0c3 += 2;
      }
      assert(iw < 4);
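      // Remainder: 1-3 input columns are left, yielding 2 output pixels
      // (iw == 3) or 1 (iw == 1 or 2); loads and FMAs past the right edge
      // are guarded by the iw checks below.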
      if XNN_UNLIKELY(iw != 0) {
        float32x4_t vo0x0 = vld1q_f32(w);
        float32x4_t vo1x0 = vo0x0;
        float32x4_t vo0x1 = vo0x0;
        float32x4_t vo1x1 = vo0x0;

        const float32x4_t vk00c0 = vld1q_f32(w + 4);

        // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
        float32x4_t vi0x1 = vld1q_f32(i0);
        float32x4_t vi1x1 = vld1q_f32(i1);
        float32x4_t vi2x1 = vld1q_f32(i2);
        float32x4_t vi3x1 = vld1q_f32(i3);
        float32x4_t vi4x1 = vld1q_f32(i4);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3);
        }

        const float32x4_t vk10c0 = vld1q_f32(w + 8);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3);
        }

        const float32x4_t vk20c0 = vld1q_f32(w + 12);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3);
        }

        const float32x4_t vk00c1 = vld1q_f32(w + 16);

        float32x4_t vi0x2 = vmovq_n_f32(0.0f);
        float32x4_t vi1x2 = vmovq_n_f32(0.0f);
        float32x4_t vi2x2 = vmovq_n_f32(0.0f);
        float32x4_t vi3x2 = vmovq_n_f32(0.0f);
        float32x4_t vi4x2 = vmovq_n_f32(0.0f);
        if (iw >= 2) {
          // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
          vi0x2 = vld1q_f32(i0 + 4);
          vi1x2 = vld1q_f32(i1 + 4);
          vi2x2 = vld1q_f32(i2 + 4);
          vi3x2 = vld1q_f32(i3 + 4);
          vi4x2 = vld1q_f32(i4 + 4);
        }

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0);

        const float32x4_t vk10c1 = vld1q_f32(w + 20);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0);

        const float32x4_t vk20c1 = vld1q_f32(w + 24);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0);

        const float32x4_t vk00c2 = vld1q_f32(w + 28);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1);

        const float32x4_t vk10c2 = vld1q_f32(w + 32);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1);

        const float32x4_t vk20c2 = vld1q_f32(w + 36);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1);

        const float32x4_t vk01c0 = vld1q_f32(w + 40);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2);
        }

        const float32x4_t vk11c0 = vld1q_f32(w + 44);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2);
        }

        const float32x4_t vk21c0 = vld1q_f32(w + 48);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2);
        }

        const float32x4_t vk01c1 = vld1q_f32(w + 52);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3);
        }

        const float32x4_t vk11c1 = vld1q_f32(w + 56);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3);
        }

        const float32x4_t vk21c1 = vld1q_f32(w + 60);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1);
        if (iw > 2) {
          vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3);
          vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3);
        }

        const float32x4_t vk01c2 = vld1q_f32(w + 64);

        float32x4_t vi0x3 = vmovq_n_f32(0.0f);
        float32x4_t vi1x3 = vmovq_n_f32(0.0f);
        float32x4_t vi2x3 = vmovq_n_f32(0.0f);
        float32x4_t vi3x3 = vmovq_n_f32(0.0f);
        float32x4_t vi4x3 = vmovq_n_f32(0.0f);
        if (iw > 2) {
          // viMx3 = ( 0.0, 0.0, 0.0, iM3c2 )
          vi0x3 = vld1q_lane_f32(i0 + 8, vi0x3, 0);
          vi1x3 = vld1q_lane_f32(i1 + 8, vi1x3, 0);
          vi2x3 = vld1q_lane_f32(i2 + 8, vi2x3, 0);
          vi3x3 = vld1q_lane_f32(i3 + 8, vi3x3, 0);
          vi4x3 = vld1q_lane_f32(i4 + 8, vi4x3, 0);
        }

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0);

        const float32x4_t vk11c2 = vld1q_f32(w + 68);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0);

        const float32x4_t vk21c2 = vld1q_f32(w + 72);

        vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2);
        vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2);
        vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0);
        vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0);

        if (iw >= 2) {
          const float32x4_t vk02c0 = vld1q_f32(w + 76);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3);

          const float32x4_t vk12c0 = vld1q_f32(w + 80);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3);

          const float32x4_t vk22c0 = vld1q_f32(w + 84);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3);

          const float32x4_t vk02c1 = vld1q_f32(w + 88);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0);

          const float32x4_t vk12c1 = vld1q_f32(w + 92);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0);

          const float32x4_t vk22c1 = vld1q_f32(w + 96);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0);

          const float32x4_t vk02c2 = vld1q_f32(w + 100);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1);

          const float32x4_t vk12c2 = vld1q_f32(w + 104);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1);

          const float32x4_t vk22c2 = vld1q_f32(w + 108);

          vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1);
          vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1);
        }

        vo0x0 = vmaxq_f32(vo0x0, vmin);
        vo1x0 = vmaxq_f32(vo1x0, vmin);
        vo0x1 = vmaxq_f32(vo0x1, vmin);
        vo1x1 = vmaxq_f32(vo1x1, vmin);

        vo0x0 = vminq_f32(vo0x0, vmax);
        vo1x0 = vminq_f32(vo1x0, vmax);
        vo0x1 = vminq_f32(vo0x1, vmax);
        vo1x1 = vminq_f32(vo1x1, vmax);

        if (iw == 3) {
          // Exactly 2 output width elements remaining
          const float32x4_t vo0c01 = vzip1q_f32(vo0x0, vo0x1);
          const float32x4_t vo0c23 = vzip2q_f32(vo0x0, vo0x1);
          const float32x4_t vo1c01 = vzip1q_f32(vo1x0, vo1x1);
          const float32x4_t vo1c23 = vzip2q_f32(vo1x0, vo1x1);

          vst1_f32(o1c0, vget_low_f32(vo1c01)); o1c0 += 2;
          vst1_f32(o1c1, vget_high_f32(vo1c01)); o1c1 += 2;
          vst1_f32(o1c2, vget_low_f32(vo1c23)); o1c2 += 2;
          vst1_f32(o1c3, vget_high_f32(vo1c23)); o1c3 += 2;

          vst1_f32(o0c0, vget_low_f32(vo0c01)); o0c0 += 2;
          vst1_f32(o0c1, vget_high_f32(vo0c01)); o0c1 += 2;
          vst1_f32(o0c2, vget_low_f32(vo0c23)); o0c2 += 2;
          vst1_f32(o0c3, vget_high_f32(vo0c23)); o0c3 += 2;
        } else {
          // Exactly 1 output width element remaining

          vst1q_lane_f32(o1c0, vo1x0, 0); o1c0 += 1;
          vst1q_lane_f32(o1c1, vo1x0, 1); o1c1 += 1;
          vst1q_lane_f32(o1c2, vo1x0, 2); o1c2 += 1;
          vst1q_lane_f32(o1c3, vo1x0, 3); o1c3 += 1;

          vst1q_lane_f32(o0c0, vo0x0, 0); o0c0 += 1;
          vst1q_lane_f32(o0c1, vo0x0, 1); o0c1 += 1;
          vst1q_lane_f32(o0c2, vo0x0, 2); o0c2 += 1;
          vst1q_lane_f32(o0c3, vo0x0, 3); o0c3 += 1;
        }
      }
      // Move output pointers back to the position of the first pixel in a row,
      // and forward to the next block of output channels.
      o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment);
      o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment);
      o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment);
      o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment);
      o1c0 = (float*) ((uintptr_t) o1c0 + output_channel_increment);
      o1c1 = (float*) ((uintptr_t) o1c1 + output_channel_increment);
      o1c2 = (float*) ((uintptr_t) o1c2 + output_channel_increment);
      o1c3 = (float*) ((uintptr_t) o1c3 + output_channel_increment);
      // Revert input pointers to the position of the first pixel in a row
      i0 = (const float*) ((uintptr_t) i0 - input_width_increment);
      i1 = (const float*) ((uintptr_t) i1 - input_width_increment);
      i2 = (const float*) ((uintptr_t) i2 - input_width_increment);
      i3 = (const float*) ((uintptr_t) i3 - input_width_increment);
      i4 = (const float*) ((uintptr_t) i4 - input_width_increment);
      // Move to the block of weights for the next 4 output channels
      w += 112;
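      // doz() is a saturating decrement: c - 4, floored at zero.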
      c = doz(c, 4);
    } while (c != 0);
    // Move output pointers forward to the next two rows
    output0 = (float*) ((uintptr_t) output1 + output_height_stride);
    output1 = (float*) ((uintptr_t) output0 + output_height_stride);
    // Move input pointers forward to the next four rows
    i0 = i4;
    i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
    i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
    i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
    i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
  }
}