// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/avgpool.h>
#include <xnnpack/common.h>


void xnn_q8_avgpool_ukernel_up9__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    const uint8_t* zero,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_q8_avgpool_params params[restrict static 1])
{
  assert(n != 0);
  assert(ks != 0);
  assert(ks <= 9);
  assert(kc != 0);

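  // Load the requantization parameters (bias, fixed-point multiplier, rounding shift,
  // output zero point, and clamping bounds) once; they are reused for every output pixel.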
  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
#if XNN_ARCH_ARM64
  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
#else
  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
#endif
  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);

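  // Each iteration of the outer loop produces one output pixel by averaging `ks`
  // pooling elements (at most 9) across `kc` channels.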
  do {
    const uint8_t* i0 = input[0];
    const uint8_t* i1 = input[1];
    const uint8_t* i2 = input[2];
    const uint8_t* i3 = input[3];
    const uint8_t* i4 = input[4];
    const uint8_t* i5 = input[5];
    const uint8_t* i6 = input[6];
    const uint8_t* i7 = input[7];
    const uint8_t* i8 = input[8];
    input = (const uint8_t**) ((uintptr_t) input + input_increment);
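    // For pooling windows with fewer than 9 elements, redirect the unused row pointers
    // to the zero vector so they contribute nothing to the sum.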
    if (ks < 2) {
      i1 = zero;
    }
    if (ks <= 2) {
      i2 = zero;
    }
    if (ks < 4) {
      i3 = zero;
    }
    if (ks <= 4) {
      i4 = zero;
    }
    if (ks < 6) {
      i5 = zero;
    }
    if (ks <= 6) {
      i6 = zero;
    }
    if (ks < 8) {
      i7 = zero;
    }
    if (ks <= 8) {
      i8 = zero;
    }

    size_t k = kc;
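    // Main loop: process 8 channels per iteration.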
    while (k >= 8) {
      const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
      const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
      const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
      const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
      const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
      const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
      const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
      const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
      const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;

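      // Sum the 9 rows with a tree of widening additions; the 16-bit lanes cannot
      // overflow because 9 * 255 = 2295 fits comfortably in uint16_t.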
      const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
      const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
      const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
      const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

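      // Add the precomputed bias and widen the per-channel sums to 32 bits.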
      int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
      int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

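      // Build an all-ones mask for lanes with negative accumulators; it adjusts the
      // 64-bit products below so the rounding shift rounds ties away from zero.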
      const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
      const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

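      // Widen to 64 bits while multiplying by the fixed-point multiplier; the AArch64
      // path uses the *_high_* forms to avoid explicit vget_high_s32 extractions.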
#if XNN_ARCH_ARM64
      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
      const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
      const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
      const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
      const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
      const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
      const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
      const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
      const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

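      // Scale the products back down with a rounding shift (vrshlq_s64 performs a
      // rounding arithmetic right shift when left_shift is negative).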
      const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
      const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
      const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
      const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
      vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
      vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

      const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
      vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
      vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

      const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

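      // Narrow to 8 bits with saturation and clamp to the output range
      // (the output zero point was added above).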
      uint8x8_t vout = vqmovun_s16(vacc);
      vout = vmax_u8(vout, voutput_min);
      vout = vmin_u8(vout, voutput_max);

      vst1_u8(output, vout); output += 8;

      k -= 8;
    }
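    // Handle the remaining 1-7 channels: full 8-byte vectors are still loaded,
    // but only `k` output bytes are stored at the end.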
    if (k != 0) {
      const uint8x8_t vi0 = vld1_u8(i0);
      const uint8x8_t vi1 = vld1_u8(i1);
      const uint8x8_t vi2 = vld1_u8(i2);
      const uint8x8_t vi3 = vld1_u8(i3);
      const uint8x8_t vi4 = vld1_u8(i4);
      const uint8x8_t vi5 = vld1_u8(i5);
      const uint8x8_t vi6 = vld1_u8(i6);
      const uint8x8_t vi7 = vld1_u8(i7);
      const uint8x8_t vi8 = vld1_u8(i8);

      const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
      const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
      const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
      const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

      int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
      int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

      const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
      const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));

#if XNN_ARCH_ARM64
      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
      const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
      const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);

      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
      const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
      const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
#else
      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
      const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
      const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);

      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
      const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
      const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
#endif

      const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
      const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
      const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
      const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);

#if XNN_ARCH_ARM64
      vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
      vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));

      const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
      vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
      vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));

      const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif

      uint8x8_t vout = vqmovun_s16(vacc);
      vout = vmax_u8(vout, voutput_min);
      vout = vmin_u8(vout, voutput_max);

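      // Store the low `k` bytes of the result 4, 2, and 1 bytes at a time,
      // rotating the vector after each partial store.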
      if (k & 4) {
        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
        vout = vext_u8(vout, vout, 4);
      }
      if (k & 2) {
        vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
        vout = vext_u8(vout, vout, 2);
      }
      if (k & 1) {
        vst1_lane_u8(output, vout, 0); output += 1;
      }
    }
    output = (uint8_t*) ((uintptr_t) output + output_increment);
  } while (--n != 0);
}