// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/dwconv.h>


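// q8 (quantized uint8) depthwise convolution microkernel: the up8x9 variant
// covers 9 kernel taps (e.g. a 3x3 kernel) and processes channels in groups
// of up to 8 per pass.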
void xnn_q8_dwconv_ukernel_up8x9__neon(
    size_t channels,
    size_t output_width,
    const uint8_t** input,
    const void* weights,
    uint8_t* output,
    size_t input_stride,
    size_t output_increment,
    const union xnn_q8_gemm_params params[restrict static 1])
{
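  // Load the requantization parameters (kernel zero point, Q31 fixed-point
  // multiplier, rounding right shift, output zero point, and clamping bounds)
  // and broadcast each one across a vector.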
  const uint8x8_t vkernel_zero_point = vld1_dup_u8((const uint8_t*) &params->neon.kernel_zero_point);
  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
  const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);

#if XNN_ARCH_ARM64
  // The larger number of registers on AArch64 makes it possible to process a few pixels at a time.
  if (input_stride == 3 * sizeof(void*)) {
    for (; output_width >= 3; output_width -= 3) {
      const uint8_t* i00 = input[ 0];
      const uint8_t* i10 = input[ 1];
      const uint8_t* i20 = input[ 2];
      const uint8_t* i01 = input[ 3];
      const uint8_t* i11 = input[ 4];
      const uint8_t* i21 = input[ 5];
      const uint8_t* i02 = input[ 6];
      const uint8_t* i12 = input[ 7];
      const uint8_t* i22 = input[ 8];
      const uint8_t* i03 = input[ 9];
      const uint8_t* i13 = input[10];
      const uint8_t* i23 = input[11];
      const uint8_t* i04 = input[12];
      const uint8_t* i14 = input[13];
      const uint8_t* i24 = input[14];

      uint8_t* output0 = output;
      uint8_t* output1 = output0 + channels + output_increment;
      uint8_t* output2 = output1 + channels + output_increment;

      input += 9;

      size_t c = channels;
      const void* w = weights;
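      // Packed weight layout per group of 8 channels: 8 int32 bias values,
      // then the 9 kernel taps, each tap stored as 8 uint8 values.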
      for (; c >= 8; c -= 8) {
        int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
        int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
        int32x4_t vacc1_lo = vacc0_lo;
        int32x4_t vacc2_lo = vacc0_lo;
        int32x4_t vacc1_hi = vacc0_hi;
        int32x4_t vacc2_hi = vacc0_hi;

        const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi00 = vld1_u8(i00); i00 += 8;
        const uint8x8_t vi01 = vld1_u8(i01); i01 += 8;
        const uint8x8_t vi02 = vld1_u8(i02); i02 += 8;
        const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
        const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
        const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
        const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);

        const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi10 = vld1_u8(i10); i10 += 8;
        const uint8x8_t vi11 = vld1_u8(i11); i11 += 8;
        const uint8x8_t vi12 = vld1_u8(i12); i12 += 8;
        const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
        const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
        const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
        const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);

        const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi20 = vld1_u8(i20); i20 += 8;
        const uint8x8_t vi21 = vld1_u8(i21); i21 += 8;
        const uint8x8_t vi22 = vld1_u8(i22); i22 += 8;
        const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
        const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
        const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
        const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);

        const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi03 = vld1_u8(i03); i03 += 8;
        const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
        const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);

        const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi13 = vld1_u8(i13); i13 += 8;
        const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
        const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);

        const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi23 = vld1_u8(i23); i23 += 8;
        const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
        const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);

        const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi04 = vld1_u8(i04); i04 += 8;
        const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
        const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);

        const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi14 = vld1_u8(i14); i14 += 8;
        const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
        const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);

        const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi24 = vld1_u8(i24); i24 += 8;
        const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
        const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);

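        // Requantize: saturating rounding doubling multiply returning the
        // high 32 bits, i.e. a fixed-point multiplication by the Q31
        // multiplier.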
        vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
        vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
        vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
        vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
        vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
        vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);

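        // Subtract 1 from negative accumulators (via a masked arithmetic
        // shift-right-accumulate) wherever the shift is non-zero, so that the
        // rounding right shift below rounds ties away from zero instead of
        // toward positive infinity.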
        const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
        vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
        vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
        vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
        vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
        vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
        vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);

        vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
        vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
        vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
        vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
        vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
        vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);

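        // Narrow to 16 bits with saturation, add the output zero point,
        // saturate to uint8, and clamp to [output_min, output_max].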
        const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
        const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
        const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
        uint8x8_t vout0 = vqmovun_s16(vacc0);
        uint8x8_t vout1 = vqmovun_s16(vacc1);
        uint8x8_t vout2 = vqmovun_s16(vacc2);
        vout0 = vmax_u8(vout0, voutput_min);
        vout1 = vmax_u8(vout1, voutput_min);
        vout2 = vmax_u8(vout2, voutput_min);
        vout0 = vmin_u8(vout0, voutput_max);
        vout1 = vmin_u8(vout1, voutput_max);
        vout2 = vmin_u8(vout2, voutput_max);

        vst1_u8(output0, vout0); output0 += 8;
        vst1_u8(output1, vout1); output1 += 8;
        vst1_u8(output2, vout2); output2 += 8;
      }
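      // Remainder of 1-7 channels: compute a full 8-channel result and store
      // only the final c bytes via the 4-/2-/1-byte tail stores below.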
      if (c != 0) {
        int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
        int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
        int32x4_t vacc1_lo = vacc0_lo;
        int32x4_t vacc2_lo = vacc0_lo;
        int32x4_t vacc1_hi = vacc0_hi;
        int32x4_t vacc2_hi = vacc0_hi;

        const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi00 = vld1_u8(i00);
        const uint8x8_t vi01 = vld1_u8(i01);
        const uint8x8_t vi02 = vld1_u8(i02);
        const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
        const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
        const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
        const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);

        const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi10 = vld1_u8(i10);
        const uint8x8_t vi11 = vld1_u8(i11);
        const uint8x8_t vi12 = vld1_u8(i12);
        const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
        const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
        const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
        const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);

        const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi20 = vld1_u8(i20);
        const uint8x8_t vi21 = vld1_u8(i21);
        const uint8x8_t vi22 = vld1_u8(i22);
        const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
        const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
        const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
        const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);

        const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi03 = vld1_u8(i03);
        const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
        const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);

        const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi13 = vld1_u8(i13);
        const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
        const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);

        const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi23 = vld1_u8(i23);
        const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
        const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);

        const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi04 = vld1_u8(i04);
        const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
        const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);

        const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi14 = vld1_u8(i14);
        const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
        const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);

        const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
        const uint8x8_t vi24 = vld1_u8(i24);
        const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
        const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
        vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
        vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
        vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
        vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
        vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
        vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);

        vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
        vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
        vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
        vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
        vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
        vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);

        const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
        vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
        vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
        vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
        vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
        vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
        vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);

        vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
        vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
        vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
        vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
        vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
        vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);

        const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
        const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
        const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
        uint8x8_t vout0 = vqmovun_s16(vacc0);
        uint8x8_t vout1 = vqmovun_s16(vacc1);
        uint8x8_t vout2 = vqmovun_s16(vacc2);
        vout0 = vmax_u8(vout0, voutput_min);
        vout1 = vmax_u8(vout1, voutput_min);
        vout2 = vmax_u8(vout2, voutput_min);
        vout0 = vmin_u8(vout0, voutput_max);
        vout1 = vmin_u8(vout1, voutput_max);
        vout2 = vmin_u8(vout2, voutput_max);

        if (c & 4) {
          vst1_lane_u32(__builtin_assume_aligned(output0, 1), vreinterpret_u32_u8(vout0), 0); output0 += 4;
          vst1_lane_u32(__builtin_assume_aligned(output1, 1), vreinterpret_u32_u8(vout1), 0); output1 += 4;
          vst1_lane_u32(__builtin_assume_aligned(output2, 1), vreinterpret_u32_u8(vout2), 0); output2 += 4;
          vout0 = vext_u8(vout0, vout0, 4);
          vout1 = vext_u8(vout1, vout1, 4);
          vout2 = vext_u8(vout2, vout2, 4);
        }
        if (c & 2) {
          vst1_lane_u16(__builtin_assume_aligned(output0, 1), vreinterpret_u16_u8(vout0), 0); output0 += 2;
          vst1_lane_u16(__builtin_assume_aligned(output1, 1), vreinterpret_u16_u8(vout1), 0); output1 += 2;
          vst1_lane_u16(__builtin_assume_aligned(output2, 1), vreinterpret_u16_u8(vout2), 0); output2 += 2;
          vout0 = vext_u8(vout0, vout0, 2);
          vout1 = vext_u8(vout1, vout1, 2);
          vout2 = vext_u8(vout2, vout2, 2);
        }
        if (c & 1) {
          vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0); output0++;
          vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0); output1++;
          vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0); output2++;
        }
      }

      output = (uint8_t*) ((uintptr_t) output2 + output_increment);
    }
    if (output_width == 0) {
      return;
    }
  }
#endif  // XNN_ARCH_ARM64

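  // Generic NEON path: one output pixel per iteration, reading the 9 input
  // row pointers and advancing the pointer array by input_stride per pixel.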
  do {
    const uint8_t* i0 = input[0];
    const uint8_t* i1 = input[1];
    const uint8_t* i2 = input[2];
    const uint8_t* i3 = input[3];
    const uint8_t* i4 = input[4];
    const uint8_t* i5 = input[5];
    const uint8_t* i6 = input[6];
    const uint8_t* i7 = input[7];
    const uint8_t* i8 = input[8];

    input = (const uint8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
    for (; c >= 8; c -= 8) {
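      // The nine taps accumulate into two independent accumulator pairs:
      // vaccX1 starts from the bias, vaccX0 from the first tap's product;
      // the two pairs are summed once all taps are processed.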
      int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
      int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));

      const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
      const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
      const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
      int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
      int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));

      const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
      const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
      const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));

      const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
      const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
      const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));

      const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
      const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
      const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));

      const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
      const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
      const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));

      const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
      const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
      const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));

      const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
      const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
      const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));

      const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
      const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
      const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));

      const uint8x8_t vk8 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;
      const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
      const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));

      int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
      int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);

      vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
      vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);

      const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
      vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
      vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

      vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
      vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

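      // Pack the two 32-bit halves to 16 bits with saturation: AArch64
      // narrows directly into the upper half with vqmovn_high_s32, AArch32
      // combines two vqmovn_s32 results.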
#if XNN_ARCH_ARM64
      const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
      const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif
      uint8x8_t vout = vqmovun_s16(vacc);
      vout = vmax_u8(vout, voutput_min);
      vout = vmin_u8(vout, voutput_max);

      vst1_u8(output, vout); output += 8;
    }
    if (c != 0) {
      int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
      int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));

      const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi0 = vld1_u8(i0);
      const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
      const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
      int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
      int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));

      const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi1 = vld1_u8(i1);
      const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
      const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));

      const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi2 = vld1_u8(i2);
      const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
      const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));

      const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi3 = vld1_u8(i3);
      const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
      const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));

      const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi4 = vld1_u8(i4);
      const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
      const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));

      const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi5 = vld1_u8(i5);
      const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
      const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));

      const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi6 = vld1_u8(i6);
      const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
      const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));

      const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
      const uint8x8_t vi7 = vld1_u8(i7);
      const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
      const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
      vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
      vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));

      const uint8x8_t vk8 = vld1_u8(w);
      const uint8x8_t vi8 = vld1_u8(i8);
      const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
      const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
      vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
      vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));

      int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
      int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);

      vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
      vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);

      const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
      vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
      vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

      vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
      vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

#if XNN_ARCH_ARM64
      const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
#else
      const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
#endif
      uint8x8_t vout = vqmovun_s16(vacc);
      vout = vmax_u8(vout, voutput_min);
      vout = vmin_u8(vout, voutput_max);

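      // Store the final 1-7 channels: a 4-byte store, then 2, then 1,
      // rotating the vector between stores so the next lanes move into
      // position.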
      if (c & 4) {
        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
        vout = vext_u8(vout, vout, 4);
      }
      if (c & 2) {
        vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
        vout = vext_u8(vout, vout, 2);
      }
      if (c & 1) {
        vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0); output++;
      }
    }

    output = (uint8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}