// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/avgpool.h>
#include <xnnpack/common.h>

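// Multi-pass average pooling microkernel for QU8 data, processing 8 channels at
// a time: the first pass sums 9 kernel rows into an int32 accumulator buffer
// (seeded with the bias), each intermediate pass adds 8 more rows, and the last
// pass adds the remaining rows, requantizes, and writes the uint8 output.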
void xnn_qu8_avgpool_minmax_ukernel_9p8x__neon_c8(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const uint8_t** input,
    size_t input_offset,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_qu8_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(kernel_elements > 9);
  assert(channels != 0);

  // Load the requantization parameters once per call.
  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
#if XNN_ARCH_ARM64
  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
#else
  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
#endif
  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);

  do {
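    // First pass: accumulate the first 9 kernel rows into the buffer, starting
    // each accumulator from the bias value.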
    {
      const uint8_t* i0 = *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint8_t* i1 = *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint8_t* i2 = *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint8_t* i3 = *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint8_t* i4 = *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint8_t* i5 = *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint8_t* i6 = *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint8_t* i7 = *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      }
      const uint8_t* i8 = *input++;
      assert(i8 != NULL);
      if XNN_UNPREDICTABLE(i8 != zero) {
        i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
      }

      int32_t* b = buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
        const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;

        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

        const int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
        const int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(b, vacc_lo); b += 4;
        vst1q_s32(b, vacc_hi); b += 4;
      }
    }

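    // Intermediate passes: accumulate 8 further kernel rows per pass while more
    // than 8 rows remain.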
    size_t k = kernel_elements;
    for (k -= 9; k > 8; k -= 8) {
      const uint8_t* i0 = *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint8_t* i1 = *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint8_t* i2 = *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint8_t* i3 = *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint8_t* i4 = *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint8_t* i5 = *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint8_t* i6 = *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint8_t* i7 = *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      }

      int32_t* b = buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(b);
        int32x4_t vacc_hi = vld1q_s32(b + 4);

        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(b, vacc_lo); b += 4;
        vst1q_s32(b, vacc_hi); b += 4;
      }
    }

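    // Last pass: accumulate the remaining 1-8 kernel rows, requantize, and store
    // one output pixel. Row pointers beyond the remaining count are redirected to
    // the zero buffer so they contribute nothing.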
    {
      const uint8_t* i0 = input[0];
      assert(i0 != NULL);
      const uint8_t* i1 = input[1];
      const uint8_t* i2 = input[2];
      const uint8_t* i3 = input[3];
      const uint8_t* i4 = input[4];
      const uint8_t* i5 = input[5];
      const uint8_t* i6 = input[6];
      const uint8_t* i7 = input[7];
      input = (const uint8_t**) ((uintptr_t) input + input_increment);
      if (k < 2) {
        i1 = zero;
      }
      assert(i1 != NULL);
      if (k <= 2) {
        i2 = zero;
      }
      assert(i2 != NULL);
      if (k < 4) {
        i3 = zero;
      }
      assert(i3 != NULL);
      if (k <= 4) {
        i4 = zero;
      }
      assert(i4 != NULL);
      if (k < 6) {
        i5 = zero;
      }
      assert(i5 != NULL);
      if (k <= 6) {
        i6 = zero;
      }
      assert(i6 != NULL);
      if (k < 8) {
        i7 = zero;
      }
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      }
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      }
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      }
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      }
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      }
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      }
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      }
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      }

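      // Full groups of 8 channels: finish the accumulation, multiply by the
      // requantization multiplier into 64-bit products, apply a rounding shift,
      // add the output zero point, saturate to 8 bits, and clamp the result to
      // [output_min, output_max].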
      size_t c = channels;
      int32_t* b = buffer;
      while (c >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(b); b += 4;
        int32x4_t vacc_hi = vld1q_s32(b); b += 4;

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

        // All-ones mask for negative accumulator lanes, used to adjust rounding below.
        const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
        const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

        // Widen to 64 bits: accumulator * multiplier, then add the sign mask.
#if XNN_ARCH_ARM64
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
        const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
        const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
        const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

        // Rounding shift by the (signed) left_shift parameter.
        const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
        const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
        const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
        const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

        // Narrow back to 32 bits, then saturate to 16 bits and add the output zero point.
#if XNN_ARCH_ARM64
        vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
        vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
        vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
        vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

        // Saturate to unsigned 8 bits and clamp to the output range.
        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);

        vst1_u8(output, vout); output += 8;

        c -= 8;
      }
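      // Remainder of 1-7 channels: same computation as above, storing the output
      // 4, 2, and 1 bytes at a time as needed.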
      if (c != 0) {
        const uint8x8_t vi0 = vld1_u8(i0);
        const uint8x8_t vi1 = vld1_u8(i1);
        const uint8x8_t vi2 = vld1_u8(i2);
        const uint8x8_t vi3 = vld1_u8(i3);
        const uint8x8_t vi4 = vld1_u8(i4);
        const uint8x8_t vi5 = vld1_u8(i5);
        const uint8x8_t vi6 = vld1_u8(i6);
        const uint8x8_t vi7 = vld1_u8(i7);
        int32x4_t vacc_lo = vld1q_s32(b); b += 4;
        int32x4_t vacc_hi = vld1q_s32(b);

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

        const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
        const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
        const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
        const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
        const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
        const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

        const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
        const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
        const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
        const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
        vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
        vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
        vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
        vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

        const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);

        if (c & 4) {
          vst1_lane_u32((void*) output, vreinterpret_u32_u8(vout), 0); output += 4;
          vout = vext_u8(vout, vout, 4);
        }
        if (c & 2) {
          vst1_lane_u16((void*) output, vreinterpret_u16_u8(vout), 0); output += 2;
          vout = vext_u8(vout, vout, 2);
        }
        if (c & 1) {
          vst1_lane_u8(output, vout, 0); output += 1;
        }
      }
    }
    output = (uint8_t*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}