// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>

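// Multipass Q8 global average pooling microkernel: the first pass sums
// rows 0-6 of each channel group into an int32 buffer seeded with the bias,
// each middle pass accumulates 7 more rows, and the final pass adds the
// last 1-7 rows and requantizes the per-channel totals to uint8.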
void xnn_q8_gavgpool_ukernel_mp7p7q__neon(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    const union xnn_q8_avgpool_params params[restrict static 1])
{
  assert(m > 7);
  assert(n != 0);

  const uint8_t* i0 = input;
  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
  const size_t packed_n = round_up_po2(n, 8);
  const size_t input_increment = 7 * input_stride - packed_n;
  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);

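  // First pass: seed the accumulator buffer with the bias plus the sum of
  // the first 7 rows, 8 channels at a time.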
  int32_t* acc = buffer;
  for (size_t k = 0; k < n; k += 8) {
    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;

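    // Widen the 7 row values to 16 bits and sum them in a shallow adder tree.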
    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));

    const int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum));
    const int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum));

    vst1q_s32(acc, vacc_lo); acc += 4;
    vst1q_s32(acc, vacc_hi); acc += 4;
  }
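  // Middle passes: while more than 7 rows remain, accumulate 7 further rows
  // per pass into the buffer.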
  for (m -= 7; m > 7; m -= 7) {
    acc = buffer;

    i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
    i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
    i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
    i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
    i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
    i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
    i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);

    for (size_t k = 0; k < n; k += 8) {
      const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
      const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
      const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
      const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
      const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
      const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
      const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
      const int32x4_t vacc_lo = vld1q_s32(acc);
      const int32x4_t vacc_hi = vld1q_s32(acc + 4);

      const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

      const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

      const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));

      vst1q_s32(acc, vaddw_s16(vacc_lo, vget_low_s16(vsum))); acc += 4;
      vst1q_s32(acc, vaddw_s16(vacc_hi, vget_high_s16(vsum))); acc += 4;
    }
  }

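  // Final pass: load the requantization parameters (fixed-point multiplier,
  // shift, output zero point, and clamping bounds).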
#if XNN_ARCH_ARM64
  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
#else
  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
#endif
  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);

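  // 1 to 7 rows remain. Advance every row pointer, but redirect pointers
  // beyond the last remaining row to the zero vector.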
  i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
  i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
  if (m < 2) {
    i1 = zero;
  }
  i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
  if (m <= 2) {
    i2 = zero;
  }
  i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
  if (m < 4) {
    i3 = zero;
  }
  i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
  if (m <= 4) {
    i4 = zero;
  }
  i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
  if (m < 6) {
    i5 = zero;
  }
  i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
  if (m <= 6) {
    i6 = zero;
  }

  acc = buffer;
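  // Add the last group of rows to the accumulated sums and requantize,
  // 8 channels per iteration.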
  while (n >= 8) {
    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
    int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
    int32x4_t vacc_hi = vld1q_s32(acc); acc += 4;

    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

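    // Lanes of vneg_mask are all-ones (-1) where the accumulator is negative.
    // Adding them to the 64-bit products below biases the rounding shift so
    // that halfway values round away from zero instead of towards +infinity.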
    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

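    // vleft_shift is expected to hold a non-positive value, so this rounding
    // left shift acts as a rounding arithmetic right shift back to the 32-bit
    // quantized scale.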
    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

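    // Saturating-narrow to uint8 and clamp to the requested output range.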
    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);

    vst1_u8(output, vout); output += 8;

    n -= 8;
  }
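  // Remainder of 1-7 channels: compute a full 8-lane result, then store only
  // the valid lanes below.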
  if (n != 0) {
    const uint8x8_t vi0 = vld1_u8(i0);
    const uint8x8_t vi1 = vld1_u8(i1);
    const uint8x8_t vi2 = vld1_u8(i2);
    const uint8x8_t vi3 = vld1_u8(i3);
    const uint8x8_t vi4 = vld1_u8(i4);
    const uint8x8_t vi5 = vld1_u8(i5);
    const uint8x8_t vi6 = vld1_u8(i6);
    int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
    int32x4_t vacc_hi = vld1q_s32(acc);

    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);

    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);

    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);

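    // Write out 4, 2, and/or 1 byte(s) according to the low bits of n,
    // rotating vout after each partial store.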
    if (n & 4) {
      vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
      vout = vext_u8(vout, vout, 4);
    }
    if (n & 2) {
      vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
      vout = vext_u8(vout, vout, 2);
    }
    if (n & 1) {
      vst1_lane_u8(output, vout, 0);
    }
  }
}