• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <arm_neon.h>
10 
11 #include <xnnpack/common.h>
12 #include <xnnpack/dwconv.h>
13 
14 
15 void xnn_q8_dwconv_ukernel_up8x9__neon(
16     size_t channels,
17     size_t output_width,
18     const uint8_t** input,
19     const void* weights,
20     uint8_t* output,
21     size_t input_stride,
22     size_t output_increment,
23     const union xnn_q8_gemm_params params[restrict static 1])
24 {
25   const uint8x8_t vkernel_zero_point = vld1_dup_u8((const uint8_t*) &params->neon.kernel_zero_point);
26   const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
27   const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
28   const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
29   const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
30   const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);
31 
32 #if XNN_ARCH_ARM64
33   // Larger number of registers on AArch64 make it possible to process few pixels at a time.
34   if (input_stride == 3 * sizeof(void*)) {
35     for (; output_width >= 3; output_width -= 3) {
36       const uint8_t* i00 = input[ 0];
37       const uint8_t* i10 = input[ 1];
38       const uint8_t* i20 = input[ 2];
39       const uint8_t* i01 = input[ 3];
40       const uint8_t* i11 = input[ 4];
41       const uint8_t* i21 = input[ 5];
42       const uint8_t* i02 = input[ 6];
43       const uint8_t* i12 = input[ 7];
44       const uint8_t* i22 = input[ 8];
45       const uint8_t* i03 = input[ 9];
46       const uint8_t* i13 = input[10];
47       const uint8_t* i23 = input[11];
48       const uint8_t* i04 = input[12];
49       const uint8_t* i14 = input[13];
50       const uint8_t* i24 = input[14];
51 
52       uint8_t* output0 = output;
53       uint8_t* output1 = output0 + channels + output_increment;
54       uint8_t* output2 = output1 + channels + output_increment;
55 
56       input += 9;
57 
58       size_t c = channels;
59       const void* w = weights;
60       for (; c >= 8; c -= 8) {
61         int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
62         int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
63         int32x4_t vacc1_lo = vacc0_lo;
64         int32x4_t vacc2_lo = vacc0_lo;
65         int32x4_t vacc1_hi = vacc0_hi;
66         int32x4_t vacc2_hi = vacc0_hi;
67 
68         const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
69         const uint8x8_t vi00 = vld1_u8(i00); i00 += 8;
70         const uint8x8_t vi01 = vld1_u8(i01); i01 += 8;
71         const uint8x8_t vi02 = vld1_u8(i02); i02 += 8;
72         const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
73         const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
74         const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
75         const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
76         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
77         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
78         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
79         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
80         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
81         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
82 
83         const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
84         const uint8x8_t vi10 = vld1_u8(i10); i10 += 8;
85         const uint8x8_t vi11 = vld1_u8(i11); i11 += 8;
86         const uint8x8_t vi12 = vld1_u8(i12); i12 += 8;
87         const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
88         const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
89         const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
90         const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
91         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
92         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
93         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
94         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
95         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
96         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
97 
98         const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
99         const uint8x8_t vi20 = vld1_u8(i20); i20 += 8;
100         const uint8x8_t vi21 = vld1_u8(i21); i21 += 8;
101         const uint8x8_t vi22 = vld1_u8(i22); i22 += 8;
102         const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
103         const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
104         const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
105         const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
106         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
107         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
108         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
109         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
110         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
111         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
112 
113         const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
114         const uint8x8_t vi03 = vld1_u8(i03); i03 += 8;
115         const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
116         const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
117         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
118         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
119         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
120         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
121         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
122         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
123 
124         const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
125         const uint8x8_t vi13 = vld1_u8(i13); i13 += 8;
126         const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
127         const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
128         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
129         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
130         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
131         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
132         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
133         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
134 
135         const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
136         const uint8x8_t vi23 = vld1_u8(i23); i23 += 8;
137         const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
138         const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
139         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
140         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
141         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
142         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
143         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
144         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
145 
146         const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
147         const uint8x8_t vi04 = vld1_u8(i04); i04 += 8;
148         const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
149         const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
150         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
151         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
152         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
153         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
154         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
155         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
156 
157         const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
158         const uint8x8_t vi14 = vld1_u8(i14); i14 += 8;
159         const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
160         const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
161         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
162         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
163         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
164         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
165         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
166         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
167 
168         const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
169         const uint8x8_t vi24 = vld1_u8(i24); i24 += 8;
170         const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
171         const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
172         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
173         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
174         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
175         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
176         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
177         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
178 
179         vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
180         vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
181         vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
182         vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
183         vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
184         vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);
185 
186         const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
187         vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
188         vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
189         vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
190         vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
191         vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
192         vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
193 
194         vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
195         vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
196         vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
197         vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
198         vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
199         vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
200 
201         const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
202         const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
203         const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
204         uint8x8_t vout0 = vqmovun_s16(vacc0);
205         uint8x8_t vout1 = vqmovun_s16(vacc1);
206         uint8x8_t vout2 = vqmovun_s16(vacc2);
207         vout0 = vmax_u8(vout0, voutput_min);
208         vout1 = vmax_u8(vout1, voutput_min);
209         vout2 = vmax_u8(vout2, voutput_min);
210         vout0 = vmin_u8(vout0, voutput_max);
211         vout1 = vmin_u8(vout1, voutput_max);
212         vout2 = vmin_u8(vout2, voutput_max);
213 
214         vst1_u8(output0, vout0); output0 += 8;
215         vst1_u8(output1, vout1); output1 += 8;
216         vst1_u8(output2, vout2); output2 += 8;
217       }
218       if (c != 0) {
219         int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
220         int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
221         int32x4_t vacc1_lo = vacc0_lo;
222         int32x4_t vacc2_lo = vacc0_lo;
223         int32x4_t vacc1_hi = vacc0_hi;
224         int32x4_t vacc2_hi = vacc0_hi;
225 
226         const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
227         const uint8x8_t vi00 = vld1_u8(i00);
228         const uint8x8_t vi01 = vld1_u8(i01);
229         const uint8x8_t vi02 = vld1_u8(i02);
230         const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
231         const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
232         const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
233         const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
234         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
235         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
236         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
237         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
238         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
239         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
240 
241         const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
242         const uint8x8_t vi10 = vld1_u8(i10);
243         const uint8x8_t vi11 = vld1_u8(i11);
244         const uint8x8_t vi12 = vld1_u8(i12);
245         const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
246         const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
247         const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
248         const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
249         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
250         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
251         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
252         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
253         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
254         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
255 
256         const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
257         const uint8x8_t vi20 = vld1_u8(i20);
258         const uint8x8_t vi21 = vld1_u8(i21);
259         const uint8x8_t vi22 = vld1_u8(i22);
260         const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
261         const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
262         const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
263         const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
264         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
265         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
266         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
267         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
268         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
269         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
270 
271         const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
272         const uint8x8_t vi03 = vld1_u8(i03);
273         const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
274         const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
275         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
276         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
277         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
278         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
279         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
280         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
281 
282         const uint8x8_t vk11 = vld1_u8(w);  w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
283         const uint8x8_t vi13 = vld1_u8(i13);
284         const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
285         const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
286         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
287         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
288         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
289         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
290         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
291         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
292 
293         const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
294         const uint8x8_t vi23 = vld1_u8(i23);
295         const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
296         const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
297         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
298         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
299         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
300         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
301         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
302         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
303 
304         const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
305         const uint8x8_t vi04 = vld1_u8(i04);
306         const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
307         const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
308         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
309         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
310         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
311         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
312         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
313         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
314 
315         const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
316         const uint8x8_t vi14 = vld1_u8(i14);
317         const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
318         const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
319         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
320         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
321         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
322         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
323         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
324         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
325 
326         const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
327         const uint8x8_t vi24 = vld1_u8(i24);
328         const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
329         const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
330         vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
331         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
332         vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
333         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
334         vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
335         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
336 
337         vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
338         vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
339         vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
340         vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
341         vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
342         vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);
343 
344         const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
345         vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
346         vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
347         vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
348         vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
349         vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
350         vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
351 
352         vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
353         vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
354         vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
355         vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
356         vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
357         vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
358 
359         const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
360         const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
361         const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
362         uint8x8_t vout0 = vqmovun_s16(vacc0);
363         uint8x8_t vout1 = vqmovun_s16(vacc1);
364         uint8x8_t vout2 = vqmovun_s16(vacc2);
365         vout0 = vmax_u8(vout0, voutput_min);
366         vout1 = vmax_u8(vout1, voutput_min);
367         vout2 = vmax_u8(vout2, voutput_min);
368         vout0 = vmin_u8(vout0, voutput_max);
369         vout1 = vmin_u8(vout1, voutput_max);
370         vout2 = vmin_u8(vout2, voutput_max);
371 
372         if (c & 4) {
373           vst1_lane_u32(__builtin_assume_aligned(output0, 1), vreinterpret_u32_u8(vout0), 0); output0 += 4;
374           vst1_lane_u32(__builtin_assume_aligned(output1, 1), vreinterpret_u32_u8(vout1), 0); output1 += 4;
375           vst1_lane_u32(__builtin_assume_aligned(output2, 1), vreinterpret_u32_u8(vout2), 0); output2 += 4;
376           vout0 = vext_u8(vout0, vout0, 4);
377           vout1 = vext_u8(vout1, vout1, 4);
378           vout2 = vext_u8(vout2, vout2, 4);
379         }
380         if (c & 2) {
381           vst1_lane_u16(__builtin_assume_aligned(output0, 1), vreinterpret_u16_u8(vout0), 0); output0 += 2;
382           vst1_lane_u16(__builtin_assume_aligned(output1, 1), vreinterpret_u16_u8(vout1), 0); output1 += 2;
383           vst1_lane_u16(__builtin_assume_aligned(output2, 1), vreinterpret_u16_u8(vout2), 0); output2 += 2;
384           vout0 = vext_u8(vout0, vout0, 2);
385           vout1 = vext_u8(vout1, vout1, 2);
386           vout2 = vext_u8(vout2, vout2, 2);
387         }
388         if (c & 1) {
389           vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0); output0++;
390           vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0); output1++;
391           vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0); output2++;
392         }
393       }
394 
395       output = (uint8_t*) ((uintptr_t) output2 + output_increment);
396     }
397     if (output_width == 0) {
398       return;
399     }
400   }
401 #endif  // XNN_ARCH_ARM64
402 
403   do {
404     const uint8_t* i0 = input[0];
405     const uint8_t* i1 = input[1];
406     const uint8_t* i2 = input[2];
407     const uint8_t* i3 = input[3];
408     const uint8_t* i4 = input[4];
409     const uint8_t* i5 = input[5];
410     const uint8_t* i6 = input[6];
411     const uint8_t* i7 = input[7];
412     const uint8_t* i8 = input[8];
413 
414     input = (const uint8_t**) ((uintptr_t) input + input_stride);
415 
416     size_t c = channels;
417     const void* w = weights;
418     for (; c >= 8; c -= 8) {
419       int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
420       int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
421 
422       const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
423       const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
424       const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
425       const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
426       int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
427       int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
428 
429       const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
430       const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
431       const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
432       const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
433       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
434       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
435 
436       const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
437       const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
438       const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
439       const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
440       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
441       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
442 
443       const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
444       const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
445       const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
446       const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
447       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
448       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
449 
450       const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
451       const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
452       const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
453       const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
454       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
455       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
456 
457       const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
458       const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
459       const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
460       const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
461       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
462       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
463 
464       const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
465       const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
466       const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
467       const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
468       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
469       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
470 
471       const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
472       const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
473       const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
474       const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
475       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
476       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
477 
478       const uint8x8_t vk8 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
479       const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;
480       const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
481       const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
482       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
483       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
484 
485       int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
486       int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
487 
488       vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
489       vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);
490 
491       const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
492       vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
493       vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);
494 
495       vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
496       vacc_hi = vrshlq_s32(vacc_hi, vright_shift);
497 
498 #if XNN_ARCH_ARM64
499       const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
500 #else
501       const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
502 #endif
503       uint8x8_t vout = vqmovun_s16(vacc);
504       vout = vmax_u8(vout, voutput_min);
505       vout = vmin_u8(vout, voutput_max);
506 
507       vst1_u8(output, vout); output += 8;
508     }
509     if (c != 0) {
510       int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
511       int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
512 
513       const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
514       const uint8x8_t vi0 = vld1_u8(i0);
515       const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
516       const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
517       int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
518       int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
519 
520       const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
521       const uint8x8_t vi1 = vld1_u8(i1);
522       const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
523       const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
524       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
525       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
526 
527       const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
528       const uint8x8_t vi2 = vld1_u8(i2);
529       const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
530       const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
531       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
532       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
533 
534       const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
535       const uint8x8_t vi3 = vld1_u8(i3);
536       const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
537       const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
538       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
539       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
540 
541       const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
542       const uint8x8_t vi4 = vld1_u8(i4);
543       const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
544       const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
545       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
546       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
547 
548       const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
549       const uint8x8_t vi5 = vld1_u8(i5);
550       const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
551       const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
552       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
553       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
554 
555       const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
556       const uint8x8_t vi6 = vld1_u8(i6);
557       const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
558       const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
559       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
560       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
561 
562       const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
563       const uint8x8_t vi7 = vld1_u8(i7);
564       const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
565       const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
566       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
567       vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
568 
569       const uint8x8_t vk8 = vld1_u8(w);
570       const uint8x8_t vi8 = vld1_u8(i8);
571       const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
572       const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
573       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
574       vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
575 
576       int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
577       int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
578 
579       vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
580       vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);
581 
582       const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
583       vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
584       vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);
585 
586       vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
587       vacc_hi = vrshlq_s32(vacc_hi, vright_shift);
588 
589 #if XNN_ARCH_ARM64
590       const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
591 #else
592       const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
593 #endif
594       uint8x8_t vout = vqmovun_s16(vacc);
595       vout = vmax_u8(vout, voutput_min);
596       vout = vmin_u8(vout, voutput_max);
597 
598       if (c & 4) {
599         vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
600         vout = vext_u8(vout, vout, 4);
601       }
602       if (c & 2) {
603         vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
604         vout = vext_u8(vout, vout, 2);
605       }
606       if (c & 1) {
607         vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0); output++;
608       }
609     }
610 
611     output = (uint8_t*) ((uintptr_t) output + output_increment);
612   } while (--output_width != 0);
613 }
614