1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert DATATYPE in ["QS8", "QU8"] 7$assert CHANNEL_TILE % 8 == 0 8$assert CHANNEL_TILE >= 8 9$assert ROW_TILE >= 3 10$assert ROW_SUBTILE >= 3 11$assert ROW_SUBTILE <= ROW_TILE 12$assert REQUANTIZATION in ["FP32", "RNDNU"] 13$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 14#include <assert.h> 15 16#include <arm_neon.h> 17 18#include <xnnpack/gavgpool.h> 19$if ARMV8: 20 #include <xnnpack/intrinsics-polyfill.h> 21#include <xnnpack/math.h> 22 23 24$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon") 25$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE] 26$XINT8X8_T = {"QS8": "int8x8_t", "QU8": "uint8x8_t"}[DATATYPE] 27$XINT8X16_T = {"QS8": "int8x16_t", "QU8": "uint8x16_t"}[DATATYPE] 28$XINT16X8_T = {"QS8": "int16x8_t", "QU8": "uint16x8_t"}[DATATYPE] 29$VLD1_X8 = {"QS8": "vld1_s8", "QU8": "vld1_u8"}[DATATYPE] 30$VLD1_DUP_X8 = {"QS8": "vld1_dup_s8", "QU8": "vld1_dup_u8"}[DATATYPE] 31$VLD1Q_DUP_X8 = {"QS8": "vld1q_dup_s8", "QU8": "vld1q_dup_u8"}[DATATYPE] 32$VST1_X8 = {"QS8": "vst1_s8", "QU8": "vst1_u8"}[DATATYPE] 33$VST1Q_X8 = {"QS8": "vst1q_s8", "QU8": "vst1q_u8"}[DATATYPE] 34$VST1_LANE_X8 = {"QS8": "vst1_lane_s8", "QU8": "vst1_lane_u8"}[DATATYPE] 35$VADDL_X8 = {"QS8": "vaddl_s8", "QU8": "vaddl_u8"}[DATATYPE] 36$VADDW_X8 = {"QS8": "vaddw_s8", "QU8": "vaddw_u8"}[DATATYPE] 37$VMIN_X8 = {"QS8": "vmin_s8", "QU8": "vmin_u8"}[DATATYPE] 38$VMINQ_X8 = {"QS8": "vminq_s8", "QU8": "vminq_u8"}[DATATYPE] 39$VMAX_X8 = {"QS8": "vmax_s8", "QU8": "vmax_u8"}[DATATYPE] 40$VMAXQ_X8 = {"QS8": "vmaxq_s8", "QU8": "vmaxq_u8"}[DATATYPE] 41$VEXT_X8 = {"QS8": "vext_s8", "QU8": "vext_u8"}[DATATYPE] 42$VQMOVXN_S16 = {"QS8": "vqmovn_s16", "QU8": "vqmovun_s16"}[DATATYPE] 43$VQMOVXN_HIGH_S16 = {"QS8": "vqmovn_high_s16", "QU8": "vqmovun_high_s16"}[DATATYPE] 44$VGET_LOW_X8 = {"QS8": "vget_low_s8", "QU8": "vget_low_u8"}[DATATYPE] 45$VCOMBINE_X8 = {"QS8": "vcombine_s8", "QU8": "vcombine_u8"}[DATATYPE] 46$VREINTERPRET_U32_X8 = {"QS8": "vreinterpret_u32_s8", "QU8": "vreinterpret_u32_u8"}[DATATYPE] 47$VREINTERPRET_U16_X8 = {"QS8": "vreinterpret_u16_s8", "QU8": "vreinterpret_u16_u8"}[DATATYPE] 48$ISA = "neonv8" if ARMV8 else "neon" 49void xnn_${DATATYPE.lower()}_gavgpool_minmax_${REQUANTIZATION.lower()}_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__${ISA}_c${CHANNEL_TILE}( 50 size_t rows, 51 size_t channels, 52 const ${XINT8_T}* input, 53 size_t input_stride, 54 const ${XINT8_T}* zero, 55 int32_t* buffer, 56 ${XINT8_T}* output, 57 const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 58{ 59 assert(rows > ${ROW_TILE}); 60 assert(channels != 0); 61 62 const ${XINT8_T}* i0 = input; 63 $for M in range(1, ROW_TILE): 64 const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride); 65 $if CHANNEL_TILE <= 16: 66 const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, ${CHANNEL_TILE}) * sizeof(${XINT8_T}); 67 $else: 68 const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, 8) * sizeof(${XINT8_T}); 69 70 const int32x4_t vinit_bias = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.init_bias); 71 int32_t* b = buffer; 72 size_t c = channels; 73 for (; ${"c >= %d" % CHANNEL_TILE if CHANNEL_TILE > 16 else "c != 0"}; ${("c -= %d" if CHANNEL_TILE > 16 else "c = doz(c, %d)") % CHANNEL_TILE}) { 74 $for M in range(2): 75 $for C in range(0, CHANNEL_TILE, 8): 76 const ${XINT8X8_T} vi${M}x${ABC[C:C+8]} = ${VLD1_X8}(i${M}); i${M} += 8; 77 78 $for C in range(0, CHANNEL_TILE, 8): 79 const ${XINT8X8_T} vi2x${ABC[C:C+8]} = ${VLD1_X8}(i2); i2 += 8; 80 ${XINT16X8_T} vsum${ABC[C:C+8]} = ${VADDL_X8}(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}); 81 82 $for M in range(2, ROW_TILE): 83 $for C in range(0, CHANNEL_TILE, 8): 84 $if M + 1 != ROW_TILE: 85 const ${XINT8X8_T} vi${M+1}x${ABC[C:C+8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8; 86 vsum${ABC[C:C+8]} = ${VADDW_X8}(vsum${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]}); 87 88 $for C in range(0, CHANNEL_TILE, 8): 89 $if DATATYPE == "QS8": 90 const int32x4_t vacc${ABC[C:C+4]} = vaddw_s16(vinit_bias, vget_low_s16(vsum${ABC[C:C+8]})); 91 const int32x4_t vacc${ABC[C+4:C+8]} = vaddw_s16(vinit_bias, vget_high_s16(vsum${ABC[C:C+8]})); 92 $else: 93 const int32x4_t vacc${ABC[C:C+4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_low_u16(vsum${ABC[C:C+8]}))); 94 const int32x4_t vacc${ABC[C+4:C+8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_high_u16(vsum${ABC[C:C+8]}))); 95 96 $for C in range(0, CHANNEL_TILE, 4): 97 vst1q_s32(b, vacc${ABC[C:C+4]}); b += 4; 98 } 99 $if CHANNEL_TILE > 16: 100 if XNN_UNLIKELY(c != 0) { 101 do { 102 $for M in range(3): 103 const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); i${M} += 8; 104 ${XINT16X8_T} vsum${ABC[0:8]} = ${VADDL_X8}(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}); 105 106 $for M in range(2, ROW_TILE): 107 $if M + 1 != ROW_TILE: 108 const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8; 109 vsum${ABC[0:8]} = ${VADDW_X8}(vsum${ABC[0:8]}, vi${M}x${ABC[0:8]}); 110 111 $if DATATYPE == "QS8": 112 const int32x4_t vacc${ABC[0:4]} = vaddw_s16(vinit_bias, vget_low_s16(vsum${ABC[0:8]})); 113 const int32x4_t vacc${ABC[4:8]} = vaddw_s16(vinit_bias, vget_high_s16(vsum${ABC[0:8]})); 114 $else: 115 const int32x4_t vacc${ABC[0:4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_low_u16(vsum${ABC[0:8]}))); 116 const int32x4_t vacc${ABC[4:8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vinit_bias), vget_high_u16(vsum${ABC[0:8]}))); 117 118 vst1q_s32(b, vacc${ABC[0:4]}); b += 4; 119 vst1q_s32(b, vacc${ABC[4:8]}); b += 4; 120 121 c = doz(c, 8); 122 } while (c != 0); 123 } 124 125 for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) { 126 $for M in range(ROW_SUBTILE): 127 i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment); 128 129 int32_t* b = buffer; 130 size_t c = channels; 131 for (; ${"c >= %d" % CHANNEL_TILE if CHANNEL_TILE > 16 else "c != 0"}; ${("c -= %d" if CHANNEL_TILE > 16 else "c = doz(c, %d)") % CHANNEL_TILE}) { 132 $for M in range(2): 133 $for C in range(0, CHANNEL_TILE, 8): 134 const ${XINT8X8_T} vi${M}x${ABC[C:C+8]} = ${VLD1_X8}(i${M}); i${M} += 8; 135 136 $for C in range(0, CHANNEL_TILE, 8): 137 const ${XINT8X8_T} vi2x${ABC[C:C+8]} = ${VLD1_X8}(i2); i2 += 8; 138 ${XINT16X8_T} vsum${ABC[C:C+8]} = ${VADDL_X8}(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}); 139 140 $for M in range(2, ROW_TILE): 141 $for C in range(0, CHANNEL_TILE, 8): 142 $if M + 1 != ROW_TILE: 143 const ${XINT8X8_T} vi${M+1}x${ABC[C:C+8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8; 144 $else: 145 $if C == 0: 146 int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(b); 147 $else: 148 int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(b + ${C}); 149 int32x4_t vacc${ABC[C+4:C+8]} = vld1q_s32(b + ${C+4}); 150 vsum${ABC[C:C+8]} = ${VADDW_X8}(vsum${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]}); 151 152 $for C in range(0, CHANNEL_TILE, 8): 153 $if DATATYPE == "QS8": 154 vacc${ABC[C:C+4]} = vaddw_s16(vacc${ABC[C:C+4]}, vget_low_s16(vsum${ABC[C:C+8]})); 155 vacc${ABC[C+4:C+8]} = vaddw_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vsum${ABC[C:C+8]})); 156 $else: 157 vacc${ABC[C:C+4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C:C+4]}), vget_low_u16(vsum${ABC[C:C+8]}))); 158 vacc${ABC[C+4:C+8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C+4:C+8]}), vget_high_u16(vsum${ABC[C:C+8]}))); 159 160 $for C in range(0, CHANNEL_TILE, 4): 161 vst1q_s32(b, vacc${ABC[C:C+4]}); b += 4; 162 } 163 $if CHANNEL_TILE > 16: 164 if XNN_UNLIKELY(c != 0) { 165 do { 166 $for M in range(3): 167 const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); i${M} += 8; 168 ${XINT16X8_T} vsum${ABC[0:8]} = ${VADDL_X8}(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}); 169 170 $for M in range(2, ROW_TILE): 171 $if M + 1 != ROW_TILE: 172 const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8; 173 $else: 174 int32x4_t vacc${ABC[0:4]} = vld1q_s32(b); 175 int32x4_t vacc${ABC[4:8]} = vld1q_s32(b + 4); 176 vsum${ABC[0:8]} = ${VADDW_X8}(vsum${ABC[0:8]}, vi${M}x${ABC[0:8]}); 177 178 $if DATATYPE == "QS8": 179 vacc${ABC[0:4]} = vaddw_s16(vacc${ABC[0:4]}, vget_low_s16(vsum${ABC[0:8]})); 180 vacc${ABC[4:8]} = vaddw_s16(vacc${ABC[4:8]}, vget_high_s16(vsum${ABC[0:8]})); 181 $else: 182 vacc${ABC[0:4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[0:4]}), vget_low_u16(vsum${ABC[0:8]}))); 183 vacc${ABC[4:8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[4:8]}), vget_high_u16(vsum${ABC[0:8]}))); 184 185 vst1q_s32(b, vacc${ABC[0:4]}); b += 4; 186 vst1q_s32(b, vacc${ABC[4:8]}); b += 4; 187 188 c = doz(c, 8); 189 } while (c != 0); 190 } 191 } 192 193 i0 = (const ${XINT8_T}*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment); 194 $for M in range(1, ROW_SUBTILE): 195 i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment); 196 $if M % 2 == 1: 197 if XNN_UNPREDICTABLE(rows < ${M+1}) { 198 i${M} = zero; 199 } 200 $else: 201 if XNN_UNPREDICTABLE(rows <= ${M}) { 202 i${M} = zero; 203 } 204 205 $if REQUANTIZATION == "FP32": 206 const float32x4_t vscale = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.scale); 207 $if ARMV8: 208 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->fp32_neonv8.output_zero_point); 209 $else: 210 const float32x4_t vmagic_bias = vld1q_dup_f32(¶ms->fp32_neon.magic_bias); 211 const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(¶ms->fp32_neon.magic_bias_less_output_zero_point); 212 $elif REQUANTIZATION == "RNDNU": 213 const int32x4_t vleft_pre_shift = vld1q_dup_s32(¶ms->rndnu_neon.left_pre_shift); 214 const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->rndnu_neon.multiplier); 215 const int32x4_t vleft_post_shift = vld1q_dup_s32(¶ms->rndnu_neon.left_post_shift); 216 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->rndnu_neon.output_zero_point); 217 $if CHANNEL_TILE > 8: 218 const ${XINT8X16_T} voutput_min = ${VLD1Q_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_min); 219 const ${XINT8X16_T} voutput_max = ${VLD1Q_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_max); 220 $else: 221 const ${XINT8X8_T} voutput_min = ${VLD1_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_min); 222 const ${XINT8X8_T} voutput_max = ${VLD1_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_max); 223 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { 224 $for M in range(2): 225 $for C in range(0, CHANNEL_TILE, 8): 226 const ${XINT8X8_T} vi${M}x${ABC[C:C+8]} = ${VLD1_X8}(i${M}); i${M} += 8; 227 228 $for C in range(0, CHANNEL_TILE, 8): 229 const ${XINT8X8_T} vi2x${ABC[C:C+8]} = ${VLD1_X8}(i2); i2 += 8; 230 ${XINT16X8_T} vsum${ABC[C:C+8]} = ${VADDL_X8}(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}); 231 232 $for M in range(2, ROW_TILE): 233 $for C in range(0, CHANNEL_TILE, 8): 234 $if M + 1 != ROW_TILE: 235 const ${XINT8X8_T} vi${M+1}x${ABC[C:C+8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8; 236 $else: 237 int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(buffer); buffer += 4; 238 int32x4_t vacc${ABC[C+4:C+8]} = vld1q_s32(buffer); buffer += 4; 239 vsum${ABC[C:C+8]} = ${VADDW_X8}(vsum${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]}); 240 241 $for C in range(0, CHANNEL_TILE, 8): 242 $if DATATYPE == "QS8": 243 vacc${ABC[C:C+4]} = vaddw_s16(vacc${ABC[C:C+4]}, vget_low_s16(vsum${ABC[C:C+8]})); 244 vacc${ABC[C+4:C+8]} = vaddw_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vsum${ABC[C:C+8]})); 245 $else: 246 vacc${ABC[C:C+4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C:C+4]}), vget_low_u16(vsum${ABC[C:C+8]}))); 247 vacc${ABC[C+4:C+8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[C+4:C+8]}), vget_high_u16(vsum${ABC[C:C+8]}))); 248 249 $if REQUANTIZATION == "FP32": 250 $for C in range(0, CHANNEL_TILE, 4): 251 float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]}); 252 253 $for C in range(0, CHANNEL_TILE, 4): 254 vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale); 255 256 $if ARMV8: 257 $for C in range(0, CHANNEL_TILE, 4): 258 vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]}); 259 $else: 260 $for C in range(0, CHANNEL_TILE, 4): 261 vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias)); 262 263 $for C in range(0, CHANNEL_TILE, 4): 264 vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point); 265 $elif REQUANTIZATION == "RNDNU": 266 $for C in range(0, CHANNEL_TILE, 4): 267 vacc${ABC[C:C+4]} = vqshlq_s32(vacc${ABC[C:C+4]}, vleft_pre_shift); 268 269 $for C in range(0, CHANNEL_TILE, 4): 270 vacc${ABC[C:C+4]} = vqdmulhq_s32(vacc${ABC[C:C+4]}, vmultiplier); 271 272 $for C in range(0, CHANNEL_TILE, 4): 273 vacc${ABC[C:C+4]} = vrshlq_s32(vacc${ABC[C:C+4]}, vleft_post_shift); 274 275 #if XNN_ARCH_ARM64 276 $for C in range(0, CHANNEL_TILE, 8): 277 int16x8_t vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]}); 278 #else // !XNN_ARCH_ARM64 279 $for C in range(0, CHANNEL_TILE, 8): 280 int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]})); 281 #endif // !XNN_ARCH_ARM64 282 283 $if REQUANTIZATION != "FP32" or ARMV8: 284 $for C in range(0, CHANNEL_TILE, 8): 285 vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point); 286 287 #if XNN_ARCH_ARM64 288 $for C in range(0, CHANNEL_TILE, 16): 289 $if C + 8 < CHANNEL_TILE: 290 ${XINT8X16_T} vout${ABC[C:C+16]} = ${VQMOVXN_HIGH_S16}(${VQMOVXN_S16}(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]}); 291 $else: 292 ${XINT8X8_T} vout${ABC[C:C+8]} = ${VQMOVXN_S16}(vacc${ABC[C:C+8]}); 293 #else // !XNN_ARCH_ARM64 294 $for C in range(0, CHANNEL_TILE, 16): 295 $if C + 8 < CHANNEL_TILE: 296 ${XINT8X16_T} vout${ABC[C:C+16]} = ${VCOMBINE_X8}(${VQMOVXN_S16}(vacc${ABC[C:C+8]}), ${VQMOVXN_S16}(vacc${ABC[C+8:C+16]})); 297 $else: 298 ${XINT8X8_T} vout${ABC[C:C+8]} = ${VQMOVXN_S16}(vacc${ABC[C:C+8]}); 299 #endif // !XNN_ARCH_ARM64 300 301 $for C in range(0, CHANNEL_TILE, 16): 302 $if C + 8 < CHANNEL_TILE: 303 vout${ABC[C:C+16]} = ${VMAXQ_X8}(vout${ABC[C:C+16]}, voutput_min); 304 $elif CHANNEL_TILE > 8: 305 vout${ABC[C:C+8]} = ${VMAX_X8}(vout${ABC[C:C+8]}, ${VGET_LOW_X8}(voutput_min)); 306 $else: 307 vout${ABC[C:C+8]} = ${VMAX_X8}(vout${ABC[C:C+8]}, voutput_min); 308 309 $for C in range(0, CHANNEL_TILE, 16): 310 $if C + 8 < CHANNEL_TILE: 311 vout${ABC[C:C+16]} = ${VMINQ_X8}(vout${ABC[C:C+16]}, voutput_max); 312 $elif CHANNEL_TILE > 8: 313 vout${ABC[C:C+8]} = ${VMIN_X8}(vout${ABC[C:C+8]}, ${VGET_LOW_X8}(voutput_max)); 314 $else: 315 vout${ABC[C:C+8]} = ${VMIN_X8}(vout${ABC[C:C+8]}, voutput_max); 316 317 $for C in range(0, CHANNEL_TILE, 16): 318 $if C + 8 < CHANNEL_TILE: 319 ${VST1Q_X8}(output, vout${ABC[C:C+16]}); output += 16; 320 $else: 321 ${VST1_X8}(output, vout${ABC[C:C+8]}); output += 8; 322 } 323 if XNN_UNLIKELY(channels != 0) { 324 ${"do " if CHANNEL_TILE > 8 else ""}{ 325 $for M in range(3): 326 $if CHANNEL_TILE > 8: 327 const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); i${M} += 8; 328 $else: 329 const ${XINT8X8_T} vi${M}x${ABC[0:8]} = ${VLD1_X8}(i${M}); 330 ${XINT16X8_T} vsum${ABC[0:8]} = ${VADDL_X8}(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}); 331 332 $for M in range(2, ROW_TILE): 333 $if M + 1 != ROW_TILE: 334 $if CHANNEL_TILE > 8: 335 const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); i${M+1} += 8; 336 $else: 337 const ${XINT8X8_T} vi${M+1}x${ABC[0:8]} = ${VLD1_X8}(i${M+1}); 338 $else: 339 int32x4_t vacc${ABC[0:4]} = vld1q_s32(buffer); buffer += 4; 340 int32x4_t vacc${ABC[4:8]} = vld1q_s32(buffer); buffer += 4; 341 vsum${ABC[0:8]} = ${VADDW_X8}(vsum${ABC[0:8]}, vi${M}x${ABC[0:8]}); 342 343 $if DATATYPE == "QS8": 344 vacc${ABC[0:4]} = vaddw_s16(vacc${ABC[0:4]}, vget_low_s16(vsum${ABC[0:8]})); 345 vacc${ABC[4:8]} = vaddw_s16(vacc${ABC[4:8]}, vget_high_s16(vsum${ABC[0:8]})); 346 $else: 347 vacc${ABC[0:4]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[0:4]}), vget_low_u16(vsum${ABC[0:8]}))); 348 vacc${ABC[4:8]} = vreinterpretq_s32_u32(vaddw_u16(vreinterpretq_u32_s32(vacc${ABC[4:8]}), vget_high_u16(vsum${ABC[0:8]}))); 349 350 $if REQUANTIZATION == "FP32": 351 float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]}); 352 float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]}); 353 354 vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale); 355 vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale); 356 357 $if ARMV8: 358 vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]}); 359 vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]}); 360 $else: 361 vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias)); 362 vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias)); 363 364 vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point); 365 vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point); 366 $elif REQUANTIZATION == "RNDNU": 367 vacc${ABC[0:4]} = vqshlq_s32(vacc${ABC[0:4]}, vleft_pre_shift); 368 vacc${ABC[4:8]} = vqshlq_s32(vacc${ABC[4:8]}, vleft_pre_shift); 369 370 vacc${ABC[0:4]} = vqdmulhq_s32(vacc${ABC[0:4]}, vmultiplier); 371 vacc${ABC[4:8]} = vqdmulhq_s32(vacc${ABC[4:8]}, vmultiplier); 372 373 vacc${ABC[0:4]} = vrshlq_s32(vacc${ABC[0:4]}, vleft_post_shift); 374 vacc${ABC[4:8]} = vrshlq_s32(vacc${ABC[4:8]}, vleft_post_shift); 375 376 #if XNN_ARCH_ARM64 377 int16x8_t vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]}); 378 #else 379 int16x8_t vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]})); 380 #endif 381 $if REQUANTIZATION != "FP32" or ARMV8: 382 vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point); 383 384 ${XINT8X8_T} vout${ABC[0:8]} = ${VQMOVXN_S16}(vacc${ABC[0:8]}); 385 $if CHANNEL_TILE > 8: 386 vout${ABC[0:8]} = ${VMAX_X8}(vout${ABC[0:8]}, ${VGET_LOW_X8}(voutput_min)); 387 vout${ABC[0:8]} = ${VMIN_X8}(vout${ABC[0:8]}, ${VGET_LOW_X8}(voutput_max)); 388 389 if XNN_LIKELY(channels >= 8) { 390 ${VST1_X8}(output, vout${ABC[0:8]}); output += 8; 391 channels -= 8; 392 } else { 393 if (channels & 4) { 394 vst1_lane_u32((void*) output, ${VREINTERPRET_U32_X8}(vout${ABC[0:8]}), 0); output += 4; 395 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 4); 396 } 397 if (channels & 2) { 398 vst1_lane_u16((void*) output, ${VREINTERPRET_U16_X8}(vout${ABC[0:8]}), 0); output += 2; 399 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 2); 400 } 401 if (channels & 1) { 402 ${VST1_LANE_X8}(output, vout${ABC[0:8]}, 0); output += 1; 403 } 404 channels = 0; 405 } 406 $else: 407 vout${ABC[0:8]} = ${VMAX_X8}(vout${ABC[0:8]}, voutput_min); 408 vout${ABC[0:8]} = ${VMIN_X8}(vout${ABC[0:8]}, voutput_max); 409 410 if (channels & 4) { 411 vst1_lane_u32((void*) output, ${VREINTERPRET_U32_X8}(vout${ABC[0:8]}), 0); output += 4; 412 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 4); 413 } 414 if (channels & 2) { 415 vst1_lane_u16((void*) output, ${VREINTERPRET_U16_X8}(vout${ABC[0:8]}), 0); output += 2; 416 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 2); 417 } 418 if (channels & 1) { 419 ${VST1_LANE_X8}(output, vout${ABC[0:8]}, 0); 420 } 421 }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""} 422 } 423} 424