// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>


void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 16);
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
    $for N in range(NR):
      int32x4_t vacc0x${N} = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
    $for M in range(1, MR):
      $for N in range(NR):
        int32x4_t vacc${M}x${N} = vacc0x${N};

    // KC loop: process 16 bytes of K per iteration
    size_t k = 0;
    while (k < kc) {
      $for M in range(MR):
        const int8x16_t va${M} = vld1q_s8(a${M}); a${M} += 16;

      $for N in range(NR):
        const int8x16_t vb${N} = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));

      $for N in range(NR):
        $for M in range(MR):
          int16x8_t vprod${M}x${N} = vmull_s8(vget_low_s8(vb${N}), vget_low_s8(va${M}));
        $for M in range(MR):
          vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vget_high_s8(vb${N}), vget_high_s8(va${M}));
        $for M in range(MR):
          vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});

      k += 16 * sizeof(int8_t);
    }

#if XNN_ARCH_ARM64
    // AArch64: pairwise-add the per-channel accumulators down to one int32 lane per output channel.
    $for M in range(MR):
      $for N in range(0, NR, 4):
        const int32x4_t vsum${M}x${ABC[N:N+2]} = vpaddq_s32(vacc${M}x${N}, vacc${M}x${N+1});
        const int32x4_t vsum${M}x${ABC[N+2:N+4]} = vpaddq_s32(vacc${M}x${N+2}, vacc${M}x${N+3});
    $for M in range(MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]});
#else
    // AArch32: vpaddq_s32 is unavailable, so reduce with vadd_s32/vpadd_s32 and recombine.
    $for M in range(MR):
      $for N in range(0, NR, 4):
        const int32x2_t vpsum${M}x${ABC[N]} = vadd_s32(vget_low_s32(vacc${M}x${N}), vget_high_s32(vacc${M}x${N}));
        const int32x2_t vpsum${M}x${ABC[N+1]} = vadd_s32(vget_low_s32(vacc${M}x${N+1}), vget_high_s32(vacc${M}x${N+1}));
        const int32x2_t vpsum${M}x${ABC[N+2]} = vadd_s32(vget_low_s32(vacc${M}x${N+2}), vget_high_s32(vacc${M}x${N+2}));
        const int32x2_t vpsum${M}x${ABC[N+3]} = vadd_s32(vget_low_s32(vacc${M}x${N+3}), vget_high_s32(vacc${M}x${N+3}));
        const int32x2_t vsum${M}x${ABC[N:N+2]} = vpadd_s32(vpsum${M}x${ABC[N]}, vpsum${M}x${ABC[N+1]});
        const int32x2_t vsum${M}x${ABC[N+2:N+4]} = vpadd_s32(vpsum${M}x${ABC[N+2]}, vpsum${M}x${ABC[N+3]});
        int32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]});
#endif

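    // Requantize: saturating rounding doubling high multiply by the fixed-point
    // multiplier, then a rounding right shift. The vbic/vsra pair below pre-biases
    // negative results so that the final rounding shift rounds ties away from zero.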
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vqrdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vsraq_n_s32(vacc${M}x${ABC[N:N+4]}, vbicq_s32(vacc${M}x${ABC[N:N+4]}, vzero_shift_mask), 31);

    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_shift);

    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
      const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
      const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

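    // Main store path: write a full tile of ${NR} output columns per row, then
    // advance the output pointers by cn_stride and rewind the input pointers by
    // kc. The else-branch handles the final partial tile (nc < ${NR}).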
    if (nc >= ${NR}) {
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u32(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
            vst1q_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u16(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
            vst1q_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}