// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>

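// Q8 (quantized uint8) global average pooling, multi-pass variant: the first
// pass sums 7 rows per channel into the 32-bit accumulator buffer (seeded
// with the bias), each intermediate pass adds 7 more rows, and the final pass
// adds the remaining 1-7 rows and requantizes the accumulators to uint8.
// Here m is the number of rows (pooling elements) and n is the number of
// channels.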
void xnn_q8_gavgpool_ukernel_mp7p7q__neon(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    const union xnn_q8_avgpool_params params[restrict static 1])
{
  assert(m > 7);
  assert(n != 0);

  const uint8_t* i0 = input;
  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
  const size_t packed_n = round_up_po2(n, 8);
  const size_t input_increment = 7 * input_stride - packed_n;
  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);

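  // First pass: for each group of 8 channels, sum the first 7 rows, add the
  // bias, and store the widened 32-bit accumulators into the buffer.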
  int32_t* acc = buffer;
  for (size_t k = 0; k < n; k += 8) {
    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;

    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));

    const int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum));
    const int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum));

    vst1q_s32(acc, vacc_lo); acc += 4;
    vst1q_s32(acc, vacc_hi); acc += 4;
  }
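  // Intermediate passes: while more than 7 rows remain, add the next 7 rows
  // of each channel to the accumulators in the buffer.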
  for (m -= 7; m > 7; m -= 7) {
    acc = buffer;

    i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
    i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
    i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
    i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
    i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
    i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
    i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);

    for (size_t k = 0; k < n; k += 8) {
      const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
      const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
      const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
      const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
      const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
      const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
      const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
      const int32x4_t vacc_lo = vld1q_s32(acc);
      const int32x4_t vacc_hi = vld1q_s32(acc + 4);

      const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

      const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

      const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));

      vst1q_s32(acc, vaddw_s16(vacc_lo, vget_low_s16(vsum))); acc += 4;
      vst1q_s32(acc, vaddw_s16(vacc_hi, vget_high_s16(vsum))); acc += 4;
    }
  }

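  // Load the requantization parameters. On AArch64 the multiplier is
  // duplicated across a 128-bit register so vmull_high_s32 can be used; on
  // AArch32 a 64-bit register suffices for vmull_s32.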
#if XNN_ARCH_ARM64
  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
#else
  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
#endif
  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);

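  // Final pass: advance the row pointers one last time, and point the rows
  // past the remaining count (m is now between 1 and 7) at the zero buffer so
  // they add nothing to the sums.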
  i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
  i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
  if (m < 2) {
    i1 = zero;
  }
  i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
  if (m <= 2) {
    i2 = zero;
  }
  i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
  if (m < 4) {
    i3 = zero;
  }
  i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
  if (m <= 4) {
    i4 = zero;
  }
  i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
  if (m < 6) {
    i5 = zero;
  }
  i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
  if (m <= 6) {
    i6 = zero;
  }

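  // Accumulate the remaining rows and requantize 8 channels per iteration:
  // multiply by the fixed-point multiplier, apply the rounding shift, add the
  // output zero point, and clamp to [output_min, output_max].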
  acc = buffer;
  while (n >= 8) {
    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
    int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
    int32x4_t vacc_hi = vld1q_s32(acc); acc += 4;

    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

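    // Requantization: build an all-ones mask for negative accumulators; adding
    // it to the 64-bit products below makes the rounding shift round ties away
    // from zero.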
    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

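    // Widen to 64-bit accumulator-by-multiplier products. The AArch64 path
    // uses the vmull_high_s32/vaddw_high_s32 forms to avoid extracting the
    // high halves explicitly.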
#if XNN_ARCH_ARM64
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

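    // Take the low 32 bits of each shifted 64-bit lane, then saturate to
    // 16 bits and add the output zero point.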
#if XNN_ARCH_ARM64
    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);

    vst1_u8(output, vout); output += 8;

    n -= 8;
  }
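  // Remainder: the last 1-7 channels go through the same accumulate-and-
  // requantize sequence, with the result stored 4, 2, and 1 bytes at a time.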
  if (n != 0) {
    const uint8x8_t vi0 = vld1_u8(i0);
    const uint8x8_t vi1 = vld1_u8(i1);
    const uint8x8_t vi2 = vld1_u8(i2);
    const uint8x8_t vi3 = vld1_u8(i3);
    const uint8x8_t vi4 = vld1_u8(i4);
    const uint8x8_t vi5 = vld1_u8(i5);
    const uint8x8_t vi6 = vld1_u8(i6);
    int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
    int32x4_t vacc_hi = vld1q_s32(acc);

    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);

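    // Store exactly n (< 8) output bytes, rotating the vector after each
    // partial store.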
    if (n & 4) {
      vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
      vout = vext_u8(vout, vout, 4);
    }
    if (n & 2) {
      vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
      vout = vext_u8(vout, vout, 2);
    }
    if (n & 1) {
      vst1_lane_u8(output, vout, 0);
    }
  }
}