1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 7$assert NR % 8 == 0 8$assert 8 <= NR <= 16 9$assert REQUANTIZATION in ["FP32", "RNDNU"] 10$assert DATATYPE in ["QC8", "QS8", "QU8"] 11$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32" 12#include <assert.h> 13 14#include <arm_neon.h> 15 16#include <xnnpack/common.h> 17#include <xnnpack/igemm.h> 18$if REQUANTIZATION == "FP32" and ARMV8: 19 #include <xnnpack/intrinsics-polyfill.h> 20 21 22$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if REQUANTIZATION == "FP32" and ARMV8 else "neon") 23$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower() 24$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 25$XINT8X8_T = "uint8x8_t" if DATATYPE == "QU8" else "int8x8_t" 26$XINT8X16_T = "uint8x16_t" if DATATYPE == "QU8" else "int8x16_t" 27$VGET_LOW_X8 = "vget_low_u8" if DATATYPE == "QU8" else "vget_low_s8" 28$VGET_HIGH_X8 = "vget_high_u8" if DATATYPE == "QU8" else "vget_high_s8" 29$VCOMBINE_X8 = "vcombine_u8" if DATATYPE == "QU8" else "vcombine_s8" 30$VREINTERPRET_U32_X8 = "vreinterpret_u32_u8" if DATATYPE == "QU8" else "vreinterpret_u32_s8" 31$VREINTERPRETQ_U32_X8 = "vreinterpretq_u32_u8" if DATATYPE == "QU8" else "vreinterpretq_u32_s8" 32$VREINTERPRET_U16_X8 = "vreinterpret_u16_u8" if DATATYPE == "QU8" else "vreinterpret_u16_s8" 33$VREINTERPRETQ_U16_X8 = "vreinterpretq_u16_u8" if DATATYPE == "QU8" else "vreinterpretq_u16_s8" 34$VLD1_X8 = "vld1_u8" if DATATYPE == "QU8" else "vld1_s8" 35$VLD1_DUP_X8 = "vld1_dup_u8" if DATATYPE == "QU8" else "vld1_dup_s8" 36$VLD1Q_DUP_X8 = "vld1q_dup_u8" if DATATYPE == "QU8" else "vld1q_dup_s8" 37$VST1_X8 = "vst1_u8" if DATATYPE == "QU8" else "vst1_s8" 38$VST1Q_X8 = "vst1q_u8" if DATATYPE == "QU8" else "vst1q_s8" 39$VST1_LANE_X8 = "vst1_lane_u8" if DATATYPE == "QU8" else "vst1_lane_s8" 40$VST1Q_LANE_X8 = "vst1q_lane_u8" if DATATYPE == "QU8" else "vst1q_lane_s8" 41$VMIN_X8 = "vmin_u8" if DATATYPE == "QU8" else "vmin_s8" 42$VMAX_X8 = "vmax_u8" if DATATYPE == "QU8" else "vmax_s8" 43$VMINQ_X8 = "vminq_u8" if DATATYPE == "QU8" else "vminq_s8" 44$VMAXQ_X8 = "vmaxq_u8" if DATATYPE == "QU8" else "vmaxq_s8" 45$VEXT_X8 = "vext_u8" if DATATYPE == "QU8" else "vext_s8" 46$VEXTQ_X8 = "vextq_u8" if DATATYPE == "QU8" else "vextq_s8" 47$VQMOVXN_S16 = "vqmovun_s16" if DATATYPE == "QU8" else "vqmovn_s16" 48$VQMOVXN_HIGH_S16 = "vqmovun_high_s16" if DATATYPE == "QU8" else "vqmovn_high_s16" 49$ISA = "neonv8" if ARMV8 else "neon" 50void xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}__${ISA}_mlal_lane${"_prfm" if PREFETCH else ""}( 51 size_t mr, 52 size_t nc, 53 size_t kc, 54 size_t ks, 55 const ${XINT8_T}** restrict a, 56 const void* restrict w, 57 ${XINT8_T}* restrict c, 58 size_t cm_stride, 59 size_t cn_stride, 60 size_t a_offset, 61 const ${XINT8_T}* zero, 62 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 63{ 64 assert(mr != 0); 65 assert(mr <= ${MR}); 66 assert(nc != 0); 67 assert(kc != 0); 68 assert(ks != 0); 69 assert(ks % (${MR} * sizeof(void*)) == 0); 70 assert(a_offset % sizeof(${XINT8_T}) == 0); 71 assert(a != NULL); 72 assert(w != NULL); 73 assert(c != NULL); 74 75 ${XINT8_T}* c0 = c; 76 $for M in range(1, MR): 77 ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride); 78 $if M % 2 == 0: 79 if XNN_UNPREDICTABLE(mr <= ${M}) { 80 c${M} = c${M-1}; 81 } 82 $elif M + 1 == MR: 83 if XNN_UNPREDICTABLE(mr != ${M+1}) { 84 c${M} = c${M-1}; 85 } 86 $else: 87 if XNN_UNPREDICTABLE(mr < ${M+1}) { 88 c${M} = c${M-1}; 89 } 90 91 $if DATATYPE == "QU8": 92 const uint8x8_t vb_zero_point = vld1_dup_u8(¶ms->${PARAMS_STRUCT}.kernel_zero_point[0]); 93 do { 94 $for N in range(0, NR, 4): 95 int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4); 96 $for M in range(1, MR): 97 $for N in range(0, NR, 4): 98 int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]}; 99 100 size_t p = ks; 101 do { 102 $for M in range(MR): 103 const ${XINT8_T}* restrict a${M} = a[${M}]; 104 if XNN_UNPREDICTABLE(a${M} != zero) { 105 a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset); 106 } 107 a += ${MR}; 108 109 size_t k = kc; 110 while (k >= 8 * sizeof(${XINT8_T})) { 111 $for M in range(MR): 112 const ${XINT8X8_T} va${M} = ${VLD1_X8}(a${M}); a${M} += 8; 113 $if DATATYPE == "QU8": 114 const int16x8_t vxa${M} = vreinterpretq_s16_u16(vmovl_u8(va${M})); 115 $else: 116 const int16x8_t vxa${M} = vmovl_s8(va${M}); 117 118 $for K in range(4): 119 $for N in range(0, NR, 8): 120 const ${XINT8X8_T} vb${ABC[N:N+8]}c${K} = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 121 $if DATATYPE == "QU8": 122 const int16x8_t vxb${ABC[N:N+8]}c${K} = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c${K}, vb_zero_point)); 123 $else: 124 const int16x8_t vxb${ABC[N:N+8]}c${K} = vmovl_s8(vb${ABC[N:N+8]}c${K}); 125 126 $for M in range(MR): 127 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c${K}), vget_low_s16(vxa${M}), ${K}); 128 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c${K}), vget_low_s16(vxa${M}), ${K}); 129 130 $if PREFETCH: 131 $for N in range(0, NR, 8): 132 __builtin_prefetch((const ${XINT8_T}*) w + ${N * 8 + 448}); 133 134 $for K in range(4, 8): 135 $for N in range(0, NR, 8): 136 const ${XINT8X8_T} vb${ABC[N:N+8]}c${K} = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 137 $if DATATYPE == "QU8": 138 const int16x8_t vxb${ABC[N:N+8]}c${K} = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c${K}, vb_zero_point)); 139 $else: 140 const int16x8_t vxb${ABC[N:N+8]}c${K} = vmovl_s8(vb${ABC[N:N+8]}c${K}); 141 142 $for M in range(MR): 143 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c${K}), vget_high_s16(vxa${M}), ${K-4}); 144 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c${K}), vget_high_s16(vxa${M}), ${K-4}); 145 146 k -= 8 * sizeof(${XINT8_T}); 147 } 148 if XNN_UNLIKELY(k != 0) { 149 $for M in range(MR): 150 const ${XINT8X8_T} va${M} = ${VLD1_X8}(a${M}); a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + k); 151 $if DATATYPE == "QU8": 152 const int16x8_t vxa${M} = vreinterpretq_s16_u16(vmovl_u8(va${M})); 153 $else: 154 const int16x8_t vxa${M} = vmovl_s8(va${M}); 155 156 $for N in range(0, NR, 8): 157 const ${XINT8X8_T} vb${ABC[N:N+8]}c0 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 158 $if DATATYPE == "QU8": 159 const int16x8_t vxb${ABC[N:N+8]}c0 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c0, vb_zero_point)); 160 $else: 161 const int16x8_t vxb${ABC[N:N+8]}c0 = vmovl_s8(vb${ABC[N:N+8]}c0); 162 163 $for M in range(MR): 164 $for N in range(0, NR, 8): 165 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c0), vget_low_s16(vxa${M}), 0); 166 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c0), vget_low_s16(vxa${M}), 0); 167 168 if (k >= 2 * sizeof(${XINT8_T})) { 169 $for N in range(0, NR, 8): 170 const ${XINT8X8_T} vb${ABC[N:N+8]}c1 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 171 $if DATATYPE == "QU8": 172 const int16x8_t vxb${ABC[N:N+8]}c1 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c1, vb_zero_point)); 173 $else: 174 const int16x8_t vxb${ABC[N:N+8]}c1 = vmovl_s8(vb${ABC[N:N+8]}c1); 175 176 $for M in range(MR): 177 $for N in range(0, NR, 8): 178 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c1), vget_low_s16(vxa${M}), 1); 179 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c1), vget_low_s16(vxa${M}), 1); 180 181 if (k > 2 * sizeof(${XINT8_T})) { 182 $for N in range(0, NR, 8): 183 const ${XINT8X8_T} vb${ABC[N:N+8]}c2 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 184 $if DATATYPE == "QU8": 185 const int16x8_t vxb${ABC[N:N+8]}c2 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c2, vb_zero_point)); 186 $else: 187 const int16x8_t vxb${ABC[N:N+8]}c2 = vmovl_s8(vb${ABC[N:N+8]}c2); 188 189 $for M in range(MR): 190 $for N in range(0, NR, 8): 191 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c2), vget_low_s16(vxa${M}), 2); 192 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c2), vget_low_s16(vxa${M}), 2); 193 194 if (k >= 4 * sizeof(${XINT8_T})) { 195 $for N in range(0, NR, 8): 196 const ${XINT8X8_T} vb${ABC[N:N+8]}c3 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 197 $if DATATYPE == "QU8": 198 const int16x8_t vxb${ABC[N:N+8]}c3 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c3, vb_zero_point)); 199 $else: 200 const int16x8_t vxb${ABC[N:N+8]}c3 = vmovl_s8(vb${ABC[N:N+8]}c3); 201 202 $for M in range(MR): 203 $for N in range(0, NR, 8): 204 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c3), vget_low_s16(vxa${M}), 3); 205 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c3), vget_low_s16(vxa${M}), 3); 206 207 if (k > 4 * sizeof(${XINT8_T})) { 208 $for N in range(0, NR, 8): 209 const ${XINT8X8_T} vb${ABC[N:N+8]}c4 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 210 $if DATATYPE == "QU8": 211 const int16x8_t vxb${ABC[N:N+8]}c4 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c4, vb_zero_point)); 212 $else: 213 const int16x8_t vxb${ABC[N:N+8]}c4 = vmovl_s8(vb${ABC[N:N+8]}c4); 214 215 $for M in range(MR): 216 $for N in range(0, NR, 8): 217 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c4), vget_high_s16(vxa${M}), 0); 218 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c4), vget_high_s16(vxa${M}), 0); 219 220 if (k >= 6 * sizeof(${XINT8_T})) { 221 $for N in range(0, NR, 8): 222 const ${XINT8X8_T} vb${ABC[N:N+8]}c5 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 223 $if DATATYPE == "QU8": 224 const int16x8_t vxb${ABC[N:N+8]}c5 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c5, vb_zero_point)); 225 $else: 226 const int16x8_t vxb${ABC[N:N+8]}c5 = vmovl_s8(vb${ABC[N:N+8]}c5); 227 228 $for M in range(MR): 229 $for N in range(0, NR, 8): 230 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c5), vget_high_s16(vxa${M}), 1); 231 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c5), vget_high_s16(vxa${M}), 1); 232 233 if (k > 6 * sizeof(${XINT8_T})) { 234 $for N in range(0, NR, 8): 235 const ${XINT8X8_T} vb${ABC[N:N+8]}c6 = ${VLD1_X8}(w); w = (const void*) ((const ${XINT8_T}*) w + 8); 236 $if DATATYPE == "QU8": 237 const int16x8_t vxb${ABC[N:N+8]}c6 = vreinterpretq_s16_u16(vsubl_u8(vb${ABC[N:N+8]}c6, vb_zero_point)); 238 $else: 239 const int16x8_t vxb${ABC[N:N+8]}c6 = vmovl_s8(vb${ABC[N:N+8]}c6); 240 241 $for M in range(MR): 242 $for N in range(0, NR, 8): 243 vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c6), vget_high_s16(vxa${M}), 2); 244 vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c6), vget_high_s16(vxa${M}), 2); 245 } 246 } 247 } 248 } 249 } 250 } 251 } 252 p -= ${MR} * sizeof(void*); 253 } while (p != 0); 254 255 // Post-accumulation work 256 $if REQUANTIZATION == "RNDNU": 257 const int32x4_t vright_pre_shift = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.right_pre_shift); 258 const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.multiplier); 259 const int32x4_t vright_post_shift = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.right_post_shift); 260 261 $for M in range(MR): 262 $for N in range(0, NR, 4): 263 vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift); 264 265 $for M in range(MR): 266 $for N in range(0, NR, 4): 267 vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier); 268 269 $for M in range(MR): 270 $for N in range(0, NR, 4): 271 vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift); 272 $elif REQUANTIZATION == "FP32": 273 $for M in range(MR): 274 $for N in range(0, NR, 4): 275 float32x4_t vfpacc${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]}); 276 277 $if DATATYPE == "QC8": 278 $for N in range(0, NR, 4): 279 const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32((const float*) w); w = (const void*) ((const float*) w + 4); 280 $for M in range(MR): 281 vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]}); 282 $else: 283 const float32x4_t vscale = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.scale); 284 $for M in range(MR): 285 $for N in range(0, NR, 4): 286 vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale); 287 288 $if ARMV8: 289 $for M in range(MR): 290 $for N in range(0, NR, 4): 291 vacc${M}x${ABC[N:N+4]} = vcvtnq_s32_f32(vfpacc${M}x${ABC[N:N+4]}); 292 $else: 293 const float32x4_t vmagic_bias = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.magic_bias); 294 $for M in range(MR): 295 $for N in range(0, NR, 4): 296 vacc${M}x${ABC[N:N+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${M}x${ABC[N:N+4]}, vmagic_bias)); 297 298 const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.magic_bias_less_output_zero_point); 299 $for M in range(MR): 300 $for N in range(0, NR, 4): 301 vacc${M}x${ABC[N:N+4]} = vqsubq_s32(vacc${M}x${ABC[N:N+4]}, vmagic_bias_less_output_zero_point); 302 303 $if REQUANTIZATION != "FP32" or ARMV8: 304 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->${PARAMS_STRUCT}.output_zero_point); 305#if XNN_ARCH_ARM64 306 $for M in range(MR): 307 $for N in range(0, NR, 8): 308 int16x8_t vacc${M}x${ABC[N:N+8]} = vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}); 309 310 $if REQUANTIZATION != "FP32" or ARMV8: 311 $for M in range(MR): 312 $for N in range(0, NR, 8): 313 vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point); 314 315 $for M in range(MR): 316 $for N in range(0, NR, 16): 317 $if N + 8 < NR: 318 ${XINT8X16_T} vout${M}x${ABC[N:N+16]} = ${VQMOVXN_HIGH_S16}(${VQMOVXN_S16}(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]}); 319 $elif M % 2 == 1: 320 ${XINT8X16_T} vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = ${VQMOVXN_HIGH_S16}(${VQMOVXN_S16}(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]}); 321 $elif M + 1 == MR: 322 ${XINT8X8_T} vout${M}x${ABC[N:N+8]} = ${VQMOVXN_S16}(vacc${M}x${ABC[N:N+8]}); 323#else 324 $for M in range(MR): 325 $for N in range(0, NR, 8): 326 int16x8_t vacc${M}x${ABC[N:N+8]} = vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})); 327 328 $if REQUANTIZATION != "FP32" or ARMV8: 329 $for M in range(MR): 330 $for N in range(0, NR, 8): 331 vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point); 332 333 $for M in range(MR): 334 $for N in range(0, NR, 16): 335 $if N + 8 < NR: 336 ${XINT8X16_T} vout${M}x${ABC[N:N+16]} = ${VCOMBINE_X8}(${VQMOVXN_S16}(vacc${M}x${ABC[N:N+8]}), ${VQMOVXN_S16}(vacc${M}x${ABC[N+8:N+16]})); 337 $elif M % 2 == 1: 338 ${XINT8X16_T} vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = ${VCOMBINE_X8}(${VQMOVXN_S16}(vacc${M-1}x${ABC[N:N+8]}), ${VQMOVXN_S16}(vacc${M}x${ABC[N:N+8]})); 339 $elif M + 1 == MR: 340 ${XINT8X8_T} vout${M}x${ABC[N:N+8]} = ${VQMOVXN_S16}(vacc${M}x${ABC[N:N+8]}); 341#endif 342 343 $if NR == 8 and MR == 1: 344 const ${XINT8X8_T} voutput_min = ${VLD1_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_min); 345 $else: 346 const ${XINT8X16_T} voutput_min = ${VLD1Q_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_min); 347 $for M in range(MR): 348 $for N in range(0, NR, 16): 349 $if N + 8 < NR: 350 vout${M}x${ABC[N:N+16]} = ${VMAXQ_X8}(vout${M}x${ABC[N:N+16]}, voutput_min); 351 $elif M % 2 == 1: 352 vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = ${VMAXQ_X8}(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min); 353 $elif M + 1 == MR: 354 $if NR == 8 and MR == 1: 355 vout${M}x${ABC[N:N+8]} = ${VMAX_X8}(vout${M}x${ABC[N:N+8]}, voutput_min); 356 $else: 357 vout${M}x${ABC[N:N+8]} = ${VMAX_X8}(vout${M}x${ABC[N:N+8]}, ${VGET_LOW_X8}(voutput_min)); 358 359 $if NR == 8 and MR == 1: 360 const ${XINT8X8_T} voutput_max = ${VLD1_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_max); 361 $else: 362 const ${XINT8X16_T} voutput_max = ${VLD1Q_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_max); 363 $for M in range(MR): 364 $for N in range(0, NR, 16): 365 $if N + 8 < NR: 366 vout${M}x${ABC[N:N+16]} = ${VMINQ_X8}(vout${M}x${ABC[N:N+16]}, voutput_max); 367 $elif M % 2 == 1: 368 vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = ${VMINQ_X8}(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max); 369 $elif M + 1 == MR: 370 $if NR == 8 and MR == 1: 371 vout${M}x${ABC[N:N+8]} = ${VMIN_X8}(vout${M}x${ABC[N:N+8]}, voutput_max); 372 $else: 373 vout${M}x${ABC[N:N+8]} = ${VMIN_X8}(vout${M}x${ABC[N:N+8]}, ${VGET_LOW_X8}(voutput_max)); 374 375 if (nc >= ${NR}) { 376 $for M in reversed(range(MR)): 377 $for N in range(0, NR, 16): 378 $if N + 8 < NR: 379 ${VST1Q_X8}(c${M} + ${N}, vout${M}x${ABC[N:N+16]}); 380 $elif M % 2 == 1: 381 ${VST1_X8}(c${M} + ${N}, ${VGET_HIGH_X8}(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); 382 ${VST1_X8}(c${M-1} + ${N}, ${VGET_LOW_X8}(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); 383 $elif M + 1 == MR: 384 ${VST1_X8}(c${M} + ${N}, vout${M}x${ABC[N:N+8]}); 385 386 $for M in reversed(range(MR)): 387 c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); 388 389 a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks); 390 391 nc -= ${NR}; 392 } else { 393 $if NR == 16: 394 $for M in reversed(range(MR)): 395 $if M % 2 == 1: 396 ${XINT8X16_T} vout${M-1}x01234567_${M}x01234567 = ${VCOMBINE_X8}(${VGET_LOW_X8}(vout${M-1}x0123456789ABCDEF), ${VGET_LOW_X8}(vout${M}x0123456789ABCDEF)); 397 $elif M + 1 == MR: 398 ${XINT8X8_T} vout${M}x01234567 = ${VGET_LOW_X8}(vout${M}x0123456789ABCDEF); 399 if (nc & 8) { 400 $for M in reversed(range(MR)): 401 $if M % 2 == 1: 402 ${VST1_X8}(c${M}, ${VGET_HIGH_X8}(vout${M-1}x01234567_${M}x01234567)); c${M} += 8; 403 ${VST1_X8}(c${M-1}, ${VGET_LOW_X8}(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8; 404 $elif M + 1 == MR: 405 ${VST1_X8}(c${M}, vout${M}x01234567); c${M} += 8; 406 $for M in reversed(range(MR)): 407 $if M % 2 == 1: 408 vout${M-1}x01234567_${M}x01234567 = ${VCOMBINE_X8}(${VGET_HIGH_X8}(vout${M-1}x0123456789ABCDEF), ${VGET_HIGH_X8}(vout${M}x0123456789ABCDEF)); 409 $elif M + 1 == MR: 410 vout${M}x01234567 = ${VGET_HIGH_X8}(vout${M}x0123456789ABCDEF); 411 } 412 if (nc & 4) { 413 $for M in reversed(range(MR)): 414 $if M % 2 == 1: 415 vst1q_lane_u32((void*) c${M}, ${VREINTERPRETQ_U32_X8}(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4; 416 vst1q_lane_u32((void*) c${M-1}, ${VREINTERPRETQ_U32_X8}(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4; 417 $elif M + 1 == MR: 418 vst1_lane_u32((void*) c${M}, ${VREINTERPRET_U32_X8}(vout${M}x01234567), 0); c${M} += 4; 419 $for M in reversed(range(MR)): 420 $if M % 2 == 1: 421 vout${M-1}x01234567_${M}x01234567 = ${VEXTQ_X8}(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4); 422 $elif M + 1 == MR: 423 vout${M}x01234567 = ${VEXT_X8}(vout${M}x01234567, vout${M}x01234567, 4); 424 } 425 if (nc & 2) { 426 $for M in reversed(range(MR)): 427 $if M % 2 == 1: 428 vst1q_lane_u16((void*) c${M}, ${VREINTERPRETQ_U16_X8}(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2; 429 vst1q_lane_u16((void*) c${M-1}, ${VREINTERPRETQ_U16_X8}(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2; 430 $elif M + 1 == MR: 431 vst1_lane_u16((void*) c${M}, ${VREINTERPRET_U16_X8}(vout${M}x01234567), 0); c${M} += 2; 432 $for M in reversed(range(MR)): 433 $if M % 2 == 1: 434 vout${M-1}x01234567_${M}x01234567 = ${VEXTQ_X8}(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2); 435 $elif M + 1 == MR: 436 vout${M}x01234567 = ${VEXT_X8}(vout${M}x01234567, vout${M}x01234567, 2); 437 } 438 if (nc & 1) { 439 $for M in reversed(range(MR)): 440 $if M % 2 == 1: 441 ${VST1Q_LANE_X8}(c${M}, vout${M-1}x01234567_${M}x01234567, 8); 442 ${VST1Q_LANE_X8}(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0); 443 $elif M + 1 == MR: 444 ${VST1_LANE_X8}(c${M}, vout${M}x01234567, 0); 445 } 446 447 nc = 0; 448 } 449 } while (nc != 0); 450} 451