// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/avgpool.h>
#include <xnnpack/common.h>

void xnn_qu8_avgpool_minmax_ukernel_9p8x__neon_c8(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const uint8_t** input,
    size_t input_offset,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_qu8_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(kernel_elements > 9);
  assert(channels != 0);

  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
#if XNN_ARCH_ARM64
  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
#else
  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
#endif
  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);

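  // Each iteration of the outer loop produces one output pixel. The kernel
  // window is reduced in three phases: a first pass over 9 rows that seeds
  // the int32 buffer with the bias, intermediate passes over 8 rows each that
  // accumulate into the buffer, and a final pass over the remaining 1-8 rows
  // that requantizes the sums to uint8.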
  do {
    {
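      // First pass: read the first 9 row pointers (redirecting non-zero rows
      // by input_offset) and store bias + sum of the 9 rows into the buffer,
      // 8 channels at a time.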
      const uint8_t* i0 = *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint8_t* i1 = *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint8_t* i2 = *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint8_t* i3 = *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint8_t* i4 = *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint8_t* i5 = *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint8_t* i6 = *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint8_t* i7 = *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      }
      const uint8_t* i8 = *input++;
      assert(i8 != NULL);
      if XNN_UNPREDICTABLE(i8 != zero) {
        i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
      }

      int32_t* b = buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
        const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;

        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

        const int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
        const int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(b, vacc_lo); b += 4;
        vst1q_s32(b, vacc_hi); b += 4;
      }
    }

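    // Intermediate passes: fold 8 more kernel rows per pass into the buffer
    // until at most 8 rows remain for the final pass.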
    size_t k = kernel_elements;
    for (k -= 9; k > 8; k -= 8) {
      const uint8_t* i0 = *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint8_t* i1 = *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint8_t* i2 = *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint8_t* i3 = *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint8_t* i4 = *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint8_t* i5 = *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint8_t* i6 = *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint8_t* i7 = *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      }

      int32_t* b = buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(b);
        int32x4_t vacc_hi = vld1q_s32(b + 4);

        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(b, vacc_lo); b += 4;
        vst1q_s32(b, vacc_hi); b += 4;
      }
    }

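    // Final pass: accumulate the remaining k (1-8) rows, redirecting unused
    // row pointers to the zero buffer, then requantize and store the output.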
    {
      const uint8_t* i0 = input[0];
      assert(i0 != NULL);
      const uint8_t* i1 = input[1];
      const uint8_t* i2 = input[2];
      const uint8_t* i3 = input[3];
      const uint8_t* i4 = input[4];
      const uint8_t* i5 = input[5];
      const uint8_t* i6 = input[6];
      const uint8_t* i7 = input[7];
      input = (const uint8_t**) ((uintptr_t) input + input_increment);
      if (k < 2) {
        i1 = zero;
      }
      assert(i1 != NULL);
      if (k <= 2) {
        i2 = zero;
      }
      assert(i2 != NULL);
      if (k < 4) {
        i3 = zero;
      }
      assert(i3 != NULL);
      if (k <= 4) {
        i4 = zero;
      }
      assert(i4 != NULL);
      if (k < 6) {
        i5 = zero;
      }
      assert(i5 != NULL);
      if (k <= 6) {
        i6 = zero;
      }
      assert(i6 != NULL);
      if (k < 8) {
        i7 = zero;
      }
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      }
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      }
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      }
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      }
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      }
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      }
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      }
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      }

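      // Main channel loop: finish the accumulation for 8 channels at a time
      // and requantize each group to 8 output bytes.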
      size_t c = channels;
      int32_t* b = buffer;
      while (c >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(b); b += 4;
        int32x4_t vacc_hi = vld1q_s32(b); b += 4;

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

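        // Requantization: widen the accumulators to 64 bits, multiply by the
        // 32-bit fixed-point multiplier, then apply a rounding shift by
        // vleft_shift (a right shift when vleft_shift is negative). Adding the
        // all-ones mask for negative accumulators adjusts the rounding of
        // negative products so ties round away from zero.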
        const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
        const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
        const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
        const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
        const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

        const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
        const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
        const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
        const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
        vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
        vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
        vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
        vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);

        vst1_u8(output, vout); output += 8;

        c -= 8;
      }
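      // Remainder: the final 1-7 channels go through the same requantization,
      // with the result stored in 4-, 2-, and 1-byte pieces.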
      if (c != 0) {
        const uint8x8_t vi0 = vld1_u8(i0);
        const uint8x8_t vi1 = vld1_u8(i1);
        const uint8x8_t vi2 = vld1_u8(i2);
        const uint8x8_t vi3 = vld1_u8(i3);
        const uint8x8_t vi4 = vld1_u8(i4);
        const uint8x8_t vi5 = vld1_u8(i5);
        const uint8x8_t vi6 = vld1_u8(i6);
        const uint8x8_t vi7 = vld1_u8(i7);
        int32x4_t vacc_lo = vld1q_s32(b); b += 4;
        int32x4_t vacc_hi = vld1q_s32(b);

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

        const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
        const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
        const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
        const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
        const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

        const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
        const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
        const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
        const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
        vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
        vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
        vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
        vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);

        if (c & 4) {
          vst1_lane_u32((void*) output, vreinterpret_u32_u8(vout), 0); output += 4;
          vout = vext_u8(vout, vout, 4);
        }
        if (c & 2) {
          vst1_lane_u16((void*) output, vreinterpret_u16_u8(vout), 0); output += 2;
          vout = vext_u8(vout, vout, 2);
        }
        if (c & 1) {
          vst1_lane_u8(output, vout, 0); output += 1;
        }
      }
    }
    output = (uint8_t*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}