• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert DATATYPE in ["QS8", "QU8"]
7$assert CHANNEL_TILE % 8 == 0
8$assert CHANNEL_TILE >= 8
9$assert ROW_TILE >= 3
10$assert ROW_SUBTILE >= 3
11$assert ROW_SUBTILE <= ROW_TILE
12$assert REQUANTIZATION in ["FP32", "RNDNU"]
13$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
14#include <assert.h>
15
16#include <arm_neon.h>
17
18#include <xnnpack/gavgpool.h>
19$if ARMV8:
20  #include <xnnpack/intrinsics-polyfill.h>
21#include <xnnpack/math.h>
22
23
24$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon")
25$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
26$XINT8X8_T = {"QS8": "int8x8_t", "QU8": "uint8x8_t"}[DATATYPE]
27$XINT8X16_T = {"QS8": "int8x16_t", "QU8": "uint8x16_t"}[DATATYPE]
28$XINT16X8_T = {"QS8": "int16x8_t", "QU8": "uint16x8_t"}[DATATYPE]
29$VLD1_X8 = {"QS8": "vld1_s8", "QU8": "vld1_u8"}[DATATYPE]
30$VLD1_DUP_X8 = {"QS8": "vld1_dup_s8", "QU8": "vld1_dup_u8"}[DATATYPE]
31$VLD1Q_DUP_X8 = {"QS8": "vld1q_dup_s8", "QU8": "vld1q_dup_u8"}[DATATYPE]
32$VST1_X8 = {"QS8": "vst1_s8", "QU8": "vst1_u8"}[DATATYPE]
33$VST1Q_X8 = {"QS8": "vst1q_s8", "QU8": "vst1q_u8"}[DATATYPE]
34$VST1_LANE_X8 = {"QS8": "vst1_lane_s8", "QU8": "vst1_lane_u8"}[DATATYPE]
35$VADDL_X8 = {"QS8": "vaddl_s8", "QU8": "vaddl_u8"}[DATATYPE]
36$VADDW_X8 = {"QS8": "vaddw_s8", "QU8": "vaddw_u8"}[DATATYPE]
37$VMIN_X8 = {"QS8": "vmin_s8", "QU8": "vmin_u8"}[DATATYPE]
38$VMINQ_X8 = {"QS8": "vminq_s8", "QU8": "vminq_u8"}[DATATYPE]
39$VMAX_X8 = {"QS8": "vmax_s8", "QU8": "vmax_u8"}[DATATYPE]
40$VMAXQ_X8 = {"QS8": "vmaxq_s8", "QU8": "vmaxq_u8"}[DATATYPE]
41$VEXT_X8 = {"QS8": "vext_s8", "QU8": "vext_u8"}[DATATYPE]
42$VQMOVXN_S16 = {"QS8": "vqmovn_s16", "QU8": "vqmovun_s16"}[DATATYPE]
43$VQMOVXN_HIGH_S16 = {"QS8": "vqmovn_high_s16", "QU8": "vqmovun_high_s16"}[DATATYPE]
44$VGET_LOW_X8 = {"QS8": "vget_low_s8", "QU8": "vget_low_u8"}[DATATYPE]
45$VCOMBINE_X8 = {"QS8": "vcombine_s8", "QU8": "vcombine_u8"}[DATATYPE]
46$VREINTERPRET_U32_X8 = {"QS8": "vreinterpret_u32_s8", "QU8": "vreinterpret_u32_u8"}[DATATYPE]
47$VREINTERPRET_U16_X8 = {"QS8": "vreinterpret_u16_s8", "QU8": "vreinterpret_u16_u8"}[DATATYPE]
48$ISA = "neonv8" if ARMV8 else "neon"
49void xnn_${DATATYPE.lower()}_gavgpool_minmax_${REQUANTIZATION.lower()}_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__${ISA}_c${CHANNEL_TILE}(
50    size_t rows,
51    size_t channels,
52    const ${XINT8_T}* input,
53    size_t input_stride,
54    const ${XINT8_T}* zero,
55    int32_t* buffer,
56    ${XINT8_T}* output,
57    const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
58{
59  assert(rows > ${ROW_TILE});
60  assert(channels != 0);
61
62  const ${XINT8_T}* i0 = input;
63  $for M in range(1, ROW_TILE):
64    const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride);
65  $if CHANNEL_TILE <= 16:
66    const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, ${CHANNEL_TILE}) * sizeof(${XINT8_T});
67  $else:
68    const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, 8) * sizeof(${XINT8_T});
69
70  const int32x4_t vinit_bias = vld1q_dup_s32(&params->${PARAMS_STRUCT}.init_bias);
71  int32_t* b = buffer;
72  size_t c = channels;
73  for (; ${"c >= %d" % CHANNEL_TILE if CHANNEL_TILE > 16 else "c != 0"}; ${("c -= %d" if CHANNEL_TILE > 16 else "c = doz(c, %d)") % CHANNEL_TILE}) {
74    $for M in range(2):
75      $for C in range(0, CHANNEL_TILE, 8):
76        const ${XINT8X8_T} vi${M}x${ABC[C:C+8]} = ${VLD1_X8}(i${M}); i${M} += 8;
77
78    $for C in range(0, CHANNEL_TILE, 8):
79      const ${XINT8X8_T} vi2x${ABC[C:C+8]} = ${VLD1_X8}(i2); i2 += 8;
80      ${XINT16X8_T} vsum${ABC[C:C+8]} = ${VADDL_X8}(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]});
81
82    $for M in range(2, ROW_TILE):
83      $for C in range(0, CHANNEL_TILE, 8):
84        $if M + 1 != ROW_TILE:
85          const ${XINT8X8_T} vi${M+1}x${ABC[C:C+8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8;
86        vsum${ABC[C:C+8]} = ${VADDW_X8}(vsum${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]});
87
88    $for C in range(0, CHANNEL_TILE, 8):
89      $if DATATYPE == "QS8":
90        const int32x4_t vacc${ABC[C:C+4]} = vaddw_s16(vinit_bias, vget_low_s16(vsum${ABC[C:C+8]}));
91        const int32x4_t vacc${ABC[C+4:C+8]} = vaddw_s16(vinit_bias, vget_high_s16(vsum${ABC[C:C+8]}));
92      $else:
93        const int32x4_t vacc${ABC[C:C+4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_low_u16(vsum${ABC[C:C+8]})));
94        const int32x4_t vacc${ABC[C+4:C+8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_high_u16(vsum${ABC[C:C+8]})));
95
96    $for C in range(0, CHANNEL_TILE, 4):
97      vst1q_s32(b, vacc${ABC[C:C+4]}); b += 4;
98  }
99  $if CHANNEL_TILE > 16:
100    if XNN_UNLIKELY(c != 0) {
101      do {
102        $for M in range(3):
103          const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); i${M} += 8;
104        ${XINT16X8_T} vsum${ABC[0:8]} = ${VADDL_X8}(vi0x${ABC[0:8]}, vi1x${ABC[0:8]});
105
106        $for M in range(2, ROW_TILE):
107          $if M + 1 != ROW_TILE:
108            const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8;
109          vsum${ABC[0:8]} = ${VADDW_X8}(vsum${ABC[0:8]}, vi${M}x${ABC[0:8]});
110
111        $if DATATYPE == "QS8":
112          const int32x4_t vacc${ABC[0:4]} = vaddw_s16(vinit_bias, vget_low_s16(vsum${ABC[0:8]}));
113          const int32x4_t vacc${ABC[4:8]} = vaddw_s16(vinit_bias, vget_high_s16(vsum${ABC[0:8]}));
114        $else:
115          const int32x4_t vacc${ABC[0:4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_low_u16(vsum${ABC[0:8]})));
116          const int32x4_t vacc${ABC[4:8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_high_u16(vsum${ABC[0:8]})));
117
118        vst1q_s32(b, vacc${ABC[0:4]}); b += 4;
119        vst1q_s32(b, vacc${ABC[4:8]}); b += 4;
120
121        c = doz(c, 8);
122      } while (c != 0);
123    }
124
125  for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) {
126    $for M in range(ROW_SUBTILE):
127      i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
128
129    int32_t* b = buffer;
130    size_t c = channels;
131    for (; ${"c >= %d" % CHANNEL_TILE if CHANNEL_TILE > 16 else "c != 0"}; ${("c -= %d" if CHANNEL_TILE > 16 else "c = doz(c, %d)") % CHANNEL_TILE}) {
132      $for M in range(2):
133        $for C in range(0, CHANNEL_TILE, 8):
134          const ${XINT8X8_T} vi${M}x${ABC[C:C+8]} = ${VLD1_X8}(i${M}); i${M} += 8;
135
136      $for C in range(0, CHANNEL_TILE, 8):
137        const ${XINT8X8_T} vi2x${ABC[C:C+8]} = ${VLD1_X8}(i2); i2 += 8;
138        ${XINT16X8_T} vsum${ABC[C:C+8]} = ${VADDL_X8}(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]});
139
140      $for M in range(2, ROW_TILE):
141        $for C in range(0, CHANNEL_TILE, 8):
142          $if M + 1 != ROW_TILE:
143            const ${XINT8X8_T} vi${M+1}x${ABC[C:C+8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8;
144          $else:
145            $if C == 0:
146              int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(b);
147            $else:
148              int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(b + ${C});
149            int32x4_t vacc${ABC[C+4:C+8]} = vld1q_s32(b + ${C+4});
150          vsum${ABC[C:C+8]} = ${VADDW_X8}(vsum${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]});
151
152      $for C in range(0, CHANNEL_TILE, 8):
153        $if DATATYPE == "QS8":
154          vacc${ABC[C:C+4]} = vaddw_s16(vacc${ABC[C:C+4]}, vget_low_s16(vsum${ABC[C:C+8]}));
155          vacc${ABC[C+4:C+8]} = vaddw_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vsum${ABC[C:C+8]}));
156        $else:
157          vacc${ABC[C:C+4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C:C+4]}), vget_low_u16(vsum${ABC[C:C+8]})));
158          vacc${ABC[C+4:C+8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C+4:C+8]}), vget_high_u16(vsum${ABC[C:C+8]})));
159
160      $for C in range(0, CHANNEL_TILE, 4):
161        vst1q_s32(b, vacc${ABC[C:C+4]}); b += 4;
162    }
163    $if CHANNEL_TILE > 16:
164      if XNN_UNLIKELY(c != 0) {
165        do {
166          $for M in range(3):
167            const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); i${M} += 8;
168          ${XINT16X8_T} vsum${ABC[0:8]} = ${VADDL_X8}(vi0x${ABC[0:8]}, vi1x${ABC[0:8]});
169
170          $for M in range(2, ROW_TILE):
171            $if M + 1 != ROW_TILE:
172              const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8;
173            $else:
174              int32x4_t vacc${ABC[0:4]} = vld1q_s32(b);
175              int32x4_t vacc${ABC[4:8]} = vld1q_s32(b + 4);
176            vsum${ABC[0:8]} = ${VADDW_X8}(vsum${ABC[0:8]}, vi${M}x${ABC[0:8]});
177
178          $if DATATYPE == "QS8":
179            vacc${ABC[0:4]} = vaddw_s16(vacc${ABC[0:4]}, vget_low_s16(vsum${ABC[0:8]}));
180            vacc${ABC[4:8]} = vaddw_s16(vacc${ABC[4:8]}, vget_high_s16(vsum${ABC[0:8]}));
181          $else:
182            vacc${ABC[0:4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[0:4]}), vget_low_u16(vsum${ABC[0:8]})));
183            vacc${ABC[4:8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[4:8]}), vget_high_u16(vsum${ABC[0:8]})));
184
185          vst1q_s32(b, vacc${ABC[0:4]}); b += 4;
186          vst1q_s32(b, vacc${ABC[4:8]}); b += 4;
187
188          c = doz(c, 8);
189        } while (c != 0);
190      }
191  }
192
193  i0 = (const ${XINT8_T}*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment);
194  $for M in range(1, ROW_SUBTILE):
195    i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
196    $if M % 2 == 1:
197      if XNN_UNPREDICTABLE(rows < ${M+1}) {
198        i${M} = zero;
199      }
200    $else:
201      if XNN_UNPREDICTABLE(rows <= ${M}) {
202        i${M} = zero;
203      }
204
205  $if REQUANTIZATION == "FP32":
206    const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
207    $if ARMV8:
208      const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->fp32_neonv8.output_zero_point);
209    $else:
210      const float32x4_t vmagic_bias = vld1q_dup_f32(&params->fp32_neon.magic_bias);
211      const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->fp32_neon.magic_bias_less_output_zero_point);
212  $elif REQUANTIZATION == "RNDNU":
213    const int32x4_t vleft_pre_shift = vld1q_dup_s32(&params->rndnu_neon.left_pre_shift);
214    const int32x4_t vmultiplier = vld1q_dup_s32(&params->rndnu_neon.multiplier);
215    const int32x4_t vleft_post_shift = vld1q_dup_s32(&params->rndnu_neon.left_post_shift);
216    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->rndnu_neon.output_zero_point);
217  $if CHANNEL_TILE > 8:
218    const ${XINT8X16_T} voutput_min = ${VLD1Q_DUP_X8}(&params->${PARAMS_STRUCT}.output_min);
219    const ${XINT8X16_T} voutput_max = ${VLD1Q_DUP_X8}(&params->${PARAMS_STRUCT}.output_max);
220  $else:
221    const ${XINT8X8_T} voutput_min = ${VLD1_DUP_X8}(&params->${PARAMS_STRUCT}.output_min);
222    const ${XINT8X8_T} voutput_max = ${VLD1_DUP_X8}(&params->${PARAMS_STRUCT}.output_max);
223  for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
224    $for M in range(2):
225      $for C in range(0, CHANNEL_TILE, 8):
226        const ${XINT8X8_T} vi${M}x${ABC[C:C+8]} = ${VLD1_X8}(i${M}); i${M} += 8;
227
228    $for C in range(0, CHANNEL_TILE, 8):
229      const ${XINT8X8_T} vi2x${ABC[C:C+8]} = ${VLD1_X8}(i2); i2 += 8;
230      ${XINT16X8_T} vsum${ABC[C:C+8]} = ${VADDL_X8}(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]});
231
232    $for M in range(2, ROW_TILE):
233      $for C in range(0, CHANNEL_TILE, 8):
234        $if M + 1 != ROW_TILE:
235          const ${XINT8X8_T} vi${M+1}x${ABC[C:C+8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8;
236        $else:
237          int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(buffer); buffer += 4;
238          int32x4_t vacc${ABC[C+4:C+8]} = vld1q_s32(buffer); buffer += 4;
239        vsum${ABC[C:C+8]} = ${VADDW_X8}(vsum${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]});
240
241    $for C in range(0, CHANNEL_TILE, 8):
242      $if DATATYPE == "QS8":
243        vacc${ABC[C:C+4]} = vaddw_s16(vacc${ABC[C:C+4]}, vget_low_s16(vsum${ABC[C:C+8]}));
244        vacc${ABC[C+4:C+8]} = vaddw_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vsum${ABC[C:C+8]}));
245      $else:
246        vacc${ABC[C:C+4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C:C+4]}), vget_low_u16(vsum${ABC[C:C+8]})));
247        vacc${ABC[C+4:C+8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C+4:C+8]}), vget_high_u16(vsum${ABC[C:C+8]})));
248
249    $if REQUANTIZATION == "FP32":
250      $for C in range(0, CHANNEL_TILE, 4):
251        float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]});
252
253      $for C in range(0, CHANNEL_TILE, 4):
254        vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale);
255
256      $if ARMV8:
257        $for C in range(0, CHANNEL_TILE, 4):
258          vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]});
259      $else:
260        $for C in range(0, CHANNEL_TILE, 4):
261          vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias));
262
263        $for C in range(0, CHANNEL_TILE, 4):
264          vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point);
265    $elif REQUANTIZATION == "RNDNU":
266      $for C in range(0, CHANNEL_TILE, 4):
267        vacc${ABC[C:C+4]} = vqshlq_s32(vacc${ABC[C:C+4]}, vleft_pre_shift);
268
269      $for C in range(0, CHANNEL_TILE, 4):
270        vacc${ABC[C:C+4]} = vqdmulhq_s32(vacc${ABC[C:C+4]}, vmultiplier);
271
272      $for C in range(0, CHANNEL_TILE, 4):
273        vacc${ABC[C:C+4]} = vrshlq_s32(vacc${ABC[C:C+4]}, vleft_post_shift);
274
275    #if XNN_ARCH_ARM64
276      $for C in range(0, CHANNEL_TILE, 8):
277        int16x8_t vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]});
278    #else  // !XNN_ARCH_ARM64
279      $for C in range(0, CHANNEL_TILE, 8):
280        int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]}));
281    #endif  // !XNN_ARCH_ARM64
282
283    $if REQUANTIZATION != "FP32" or ARMV8:
284      $for C in range(0, CHANNEL_TILE, 8):
285        vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point);
286
287    #if XNN_ARCH_ARM64
288      $for C in range(0, CHANNEL_TILE, 16):
289        $if C + 8 < CHANNEL_TILE:
290          ${XINT8X16_T} vout${ABC[C:C+16]} = ${VQMOVXN_HIGH_S16}(${VQMOVXN_S16}(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]});
291        $else:
292          ${XINT8X8_T} vout${ABC[C:C+8]} = ${VQMOVXN_S16}(vacc${ABC[C:C+8]});
293    #else  // !XNN_ARCH_ARM64
294      $for C in range(0, CHANNEL_TILE, 16):
295        $if C + 8 < CHANNEL_TILE:
296          ${XINT8X16_T} vout${ABC[C:C+16]} = ${VCOMBINE_X8}(${VQMOVXN_S16}(vacc${ABC[C:C+8]}), ${VQMOVXN_S16}(vacc${ABC[C+8:C+16]}));
297        $else:
298          ${XINT8X8_T} vout${ABC[C:C+8]} = ${VQMOVXN_S16}(vacc${ABC[C:C+8]});
299    #endif  // !XNN_ARCH_ARM64
300
301    $for C in range(0, CHANNEL_TILE, 16):
302      $if C + 8 < CHANNEL_TILE:
303        vout${ABC[C:C+16]} = ${VMAXQ_X8}(vout${ABC[C:C+16]}, voutput_min);
304      $elif CHANNEL_TILE > 8:
305        vout${ABC[C:C+8]} = ${VMAX_X8}(vout${ABC[C:C+8]}, ${VGET_LOW_X8}(voutput_min));
306      $else:
307        vout${ABC[C:C+8]} = ${VMAX_X8}(vout${ABC[C:C+8]}, voutput_min);
308
309    $for C in range(0, CHANNEL_TILE, 16):
310      $if C + 8 < CHANNEL_TILE:
311        vout${ABC[C:C+16]} = ${VMINQ_X8}(vout${ABC[C:C+16]}, voutput_max);
312      $elif CHANNEL_TILE > 8:
313        vout${ABC[C:C+8]} = ${VMIN_X8}(vout${ABC[C:C+8]}, ${VGET_LOW_X8}(voutput_max));
314      $else:
315        vout${ABC[C:C+8]} = ${VMIN_X8}(vout${ABC[C:C+8]}, voutput_max);
316
317    $for C in range(0, CHANNEL_TILE, 16):
318      $if C + 8 < CHANNEL_TILE:
319        ${VST1Q_X8}(output, vout${ABC[C:C+16]}); output += 16;
320      $else:
321        ${VST1_X8}(output, vout${ABC[C:C+8]}); output += 8;
322  }
323  if XNN_UNLIKELY(channels != 0) {
324    ${"do " if CHANNEL_TILE > 8 else ""}{
325      $for M in range(3):
326        $if CHANNEL_TILE > 8:
327          const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); i${M} += 8;
328        $else:
329          const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M});
330      ${XINT16X8_T} vsum${ABC[0:8]} = ${VADDL_X8}(vi0x${ABC[0:8]}, vi1x${ABC[0:8]});
331
332      $for M in range(2, ROW_TILE):
333        $if M + 1 != ROW_TILE:
334          $if CHANNEL_TILE > 8:
335            const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8;
336          $else:
337            const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1});
338        $else:
339          int32x4_t vacc${ABC[0:4]} = vld1q_s32(buffer); buffer += 4;
340          int32x4_t vacc${ABC[4:8]} = vld1q_s32(buffer); buffer += 4;
341        vsum${ABC[0:8]} = ${VADDW_X8}(vsum${ABC[0:8]}, vi${M}x${ABC[0:8]});
342
343      $if DATATYPE == "QS8":
344        vacc${ABC[0:4]} = vaddw_s16(vacc${ABC[0:4]}, vget_low_s16(vsum${ABC[0:8]}));
345        vacc${ABC[4:8]} = vaddw_s16(vacc${ABC[4:8]}, vget_high_s16(vsum${ABC[0:8]}));
346      $else:
347        vacc${ABC[0:4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[0:4]}), vget_low_u16(vsum${ABC[0:8]})));
348        vacc${ABC[4:8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[4:8]}), vget_high_u16(vsum${ABC[0:8]})));
349
350      $if REQUANTIZATION == "FP32":
351        float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]});
352        float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]});
353
354        vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale);
355        vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale);
356
357        $if ARMV8:
358          vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]});
359          vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]});
360        $else:
361          vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias));
362          vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias));
363
364          vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point);
365          vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point);
366      $elif REQUANTIZATION == "RNDNU":
367        vacc${ABC[0:4]} = vqshlq_s32(vacc${ABC[0:4]}, vleft_pre_shift);
368        vacc${ABC[4:8]} = vqshlq_s32(vacc${ABC[4:8]}, vleft_pre_shift);
369
370        vacc${ABC[0:4]} = vqdmulhq_s32(vacc${ABC[0:4]}, vmultiplier);
371        vacc${ABC[4:8]} = vqdmulhq_s32(vacc${ABC[4:8]}, vmultiplier);
372
373        vacc${ABC[0:4]} = vrshlq_s32(vacc${ABC[0:4]}, vleft_post_shift);
374        vacc${ABC[4:8]} = vrshlq_s32(vacc${ABC[4:8]}, vleft_post_shift);
375
376      #if XNN_ARCH_ARM64
377        int16x8_t vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]});
378      #else
379        int16x8_t vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]}));
380      #endif
381      $if REQUANTIZATION != "FP32" or ARMV8:
382        vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point);
383
384      ${XINT8X8_T} vout${ABC[0:8]} = ${VQMOVXN_S16}(vacc${ABC[0:8]});
385      $if CHANNEL_TILE > 8:
386        vout${ABC[0:8]} = ${VMAX_X8}(vout${ABC[0:8]}, ${VGET_LOW_X8}(voutput_min));
387        vout${ABC[0:8]} = ${VMIN_X8}(vout${ABC[0:8]}, ${VGET_LOW_X8}(voutput_max));
388
389        if XNN_LIKELY(channels >= 8) {
390          ${VST1_X8}(output, vout${ABC[0:8]}); output += 8;
391          channels -= 8;
392        } else {
393          if (channels & 4) {
394            vst1_lane_u32((void*) output, ${VREINTERPRET_U32_X8}(vout${ABC[0:8]}), 0); output += 4;
395            vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
396          }
397          if (channels & 2) {
398            vst1_lane_u16((void*) output, ${VREINTERPRET_U16_X8}(vout${ABC[0:8]}), 0); output += 2;
399            vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
400          }
401          if (channels & 1) {
402            ${VST1_LANE_X8}(output, vout${ABC[0:8]}, 0); output += 1;
403          }
404          channels = 0;
405        }
406      $else:
407        vout${ABC[0:8]} = ${VMAX_X8}(vout${ABC[0:8]}, voutput_min);
408        vout${ABC[0:8]} = ${VMIN_X8}(vout${ABC[0:8]}, voutput_max);
409
410        if (channels & 4) {
411          vst1_lane_u32((void*) output, ${VREINTERPRET_U32_X8}(vout${ABC[0:8]}), 0); output += 4;
412          vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
413        }
414        if (channels & 2) {
415          vst1_lane_u16((void*) output, ${VREINTERPRET_U16_X8}(vout${ABC[0:8]}), 0); output += 2;
416          vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
417        }
418        if (channels & 1) {
419          ${VST1_LANE_X8}(output, vout${ABC[0:8]}, 0);
420        }
421    }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""}
422  }
423}
424