1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert SSE in [2, 3, 4] 7$assert not XOP or AVX 8$assert not AVX or SSE == 4 9$assert REQUANTIZATION == "FP32" 10$assert DATATYPE in ["QC8", "QS8", "QU8"] 11$assert VARIANT in ["LD64", "LD128"] 12$assert MR <= 4 13#include <assert.h> 14 15$if XOP: 16 #if defined(__GNUC__) || defined(__clang__) 17 #include <x86intrin.h> 18 #else 19 #include <immintrin.h> 20 #include <ammintrin.h> 21 #endif 22$else: 23 $SSE_HEADER = {2: "emmintrin.h", 3: "tmmintrin.h", 4: "smmintrin.h"}[SSE] 24 #include <${SSE_HEADER}> 25 26#include <xnnpack/igemm.h> 27#include <xnnpack/math.h> 28 29 30$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_%s_conv_minmax_params" % DATATYPE.lower() 31$PARAMS_STRUCT = ("" if DATATYPE == "QC8" else "fp32_") + ("sse4" if SSE >= 4 and DATATYPE != "QU8" else "sse2") 32$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 33$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE] 34void xnn_${DATATYPE.lower()}_igemm_minmax_fp32_ukernel_${MR}x4c8__${ISA}_${VARIANT.lower()}( 35 size_t mr, 36 size_t nc, 37 size_t kc, 38 size_t ks, 39 const ${XINT8_T}** restrict a, 40 const void* restrict w, 41 ${XINT8_T}* restrict c, 42 size_t cm_stride, 43 size_t cn_stride, 44 size_t a_offset, 45 const ${XINT8_T}* zero, 46 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 47{ 48 assert(mr != 0); 49 assert(mr <= ${MR}); 50 assert(nc != 0); 51 assert(kc != 0); 52 assert(ks != 0); 53 assert(ks % (${MR} * sizeof(void*)) == 0); 54 assert(a_offset % sizeof(${XINT8_T}) == 0); 55 assert(a != NULL); 56 assert(w != NULL); 57 assert(c != NULL); 58 59 kc = round_up_po2(kc, 8); 60 ${XINT8_T}* c0 = c; 61 $for M in range(1, MR): 62 ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride); 63 $if M % 2 == 0: 64 if XNN_UNPREDICTABLE(mr <= ${M}) { 65 c${M} = c${M-1}; 66 } 67 $elif M + 1 == MR: 68 if XNN_UNPREDICTABLE(mr != ${M+1}) { 69 c${M} = c${M-1}; 70 } 71 $else: 72 if XNN_UNPREDICTABLE(mr < ${M+1}) { 73 c${M} = c${M-1}; 74 } 75 76 do { 77 $for N in range(4): 78 __m128i vacc0x${N} = _mm_cvtsi32_si128((int) ((const int32_t*) w)[${N}]); 79 $for M in range(1, MR): 80 $for N in range(4): 81 __m128i vacc${M}x${N} = vacc0x${N}; 82 w = (const void*) ((const int32_t*) w + 4); 83 84 size_t p = ks; 85 do { 86 $for M in range(MR): 87 const ${XINT8_T}* restrict a${M} = a[${M}]; 88 if XNN_UNPREDICTABLE(a${M} != zero) { 89 a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset); 90 } 91 a += ${MR}; 92 93 size_t k = 0; 94 $if DATATYPE == "QU8": 95 const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.kernel_zero_point); 96 $if SSE < 4 or VARIANT == "LD128": 97 const __m128i vzero = _mm_setzero_si128(); 98 while (k < kc) { 99 $for M in range(MR): 100 const __m128i va${M} = _mm_loadl_epi64((const __m128i*) a${M}); 101 $if DATATYPE == "QU8": 102 $if SSE == 4: 103 const __m128i vxa${M} = _mm_cvtepu8_epi16(va${M}); 104 $else: 105 const __m128i vxa${M} = _mm_unpacklo_epi8(va${M}, vzero); 106 $else: 107 $if SSE == 4: 108 const __m128i vxa${M} = _mm_cvtepi8_epi16(va${M}); 109 $else: 110 const __m128i vxa${M} = _mm_srai_epi16(_mm_unpacklo_epi8(va${M}, va${M}), 8); 111 a${M} += 8; 112 113 $if VARIANT == "LD128": 114 $for N in range(0, 4, 2): 115 $if N == 0: 116 const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) w); 117 $else: 118 const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + ${N * 8})); 119 $if DATATYPE == "QU8": 120 const __m128i vxb${N} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${N}${N+1}, vzero), vb_zero_point); 121 const __m128i vxb${N+1} = _mm_sub_epi16(_mm_unpackhi_epi8(vb${N}${N+1}, vzero), vb_zero_point); 122 $elif SSE == 4: 123 const __m128i vxb${N} = _mm_cvtepi8_epi16(vb${N}${N+1}); 124 const __m128i vxb${N+1} = _mm_srai_epi16(_mm_unpackhi_epi8(vb${N}${N+1}, vb${N}${N+1}), 8); 125 $else: 126 const __m128i vsb${N}${N+1} = _mm_cmpgt_epi8(_mm_setzero_si128(), vb${N}${N+1}); 127 const __m128i vxb${N} = _mm_unpacklo_epi8(vb${N}${N+1}, vsb${N}${N+1}); 128 const __m128i vxb${N+1} = _mm_unpackhi_epi8(vb${N}${N+1}, vsb${N}${N+1}); 129 130 $for M in range(MR): 131 $if XOP: 132 vacc${M}x${N} = _mm_maddd_epi16(vxa${M}, vxb${N}, vacc${M}x${N}); 133 vacc${M}x${N+1} = _mm_maddd_epi16(vxa${M}, vxb${N+1}, vacc${M}x${N+1}); 134 $else: 135 vacc${M}x${N} = _mm_add_epi32(vacc${M}x${N}, _mm_madd_epi16(vxa${M}, vxb${N})); 136 vacc${M}x${N+1} = _mm_add_epi32(vacc${M}x${N+1}, _mm_madd_epi16(vxa${M}, vxb${N+1})); 137 $else: 138 $for N in range(4): 139 $if N == 0: 140 const __m128i vb${N} = _mm_loadl_epi64((const __m128i*) w); 141 $else: 142 const __m128i vb${N} = _mm_loadl_epi64((const __m128i*) ((const ${XINT8_T}*) w + ${N * 8})); 143 $if DATATYPE == "QU8": 144 $if SSE == 4: 145 const __m128i vxb${N} = _mm_sub_epi16(_mm_cvtepu8_epi16(vb${N}), vb_zero_point); 146 $else: 147 const __m128i vxb${N} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${N}, vzero), vb_zero_point); 148 $else: 149 $if SSE == 4: 150 const __m128i vxb${N} = _mm_cvtepi8_epi16(vb${N}); 151 $else: 152 const __m128i vxb${N} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${N}, vb${N}), 8); 153 154 $for M in range(MR): 155 $if XOP: 156 vacc${M}x${N} = _mm_maddd_epi16(vxa${M}, vxb${N}, vacc${M}x${N}); 157 $else: 158 vacc${M}x${N} = _mm_add_epi32(vacc${M}x${N}, _mm_madd_epi16(vxa${M}, vxb${N})); 159 160 w = (const void*) ((const ${XINT8_T}*) w + 32); 161 k += 8 * sizeof(${XINT8_T}); 162 } 163 p -= ${MR} * sizeof(void*); 164 } while (p != 0); 165 166 $if SSE >= 3: 167 $for M in range(MR): 168 const __m128i vacc${M}x01 = _mm_hadd_epi32(vacc${M}x0, vacc${M}x1); 169 const __m128i vacc${M}x23 = _mm_hadd_epi32(vacc${M}x2, vacc${M}x3); 170 171 $for M in range(MR): 172 __m128i vacc${M}x0123 = _mm_hadd_epi32(vacc${M}x01, vacc${M}x23); 173 $else: 174 $for M in range(MR): 175 const __m128i vacc${M}x02 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x0, vacc${M}x2), _mm_unpackhi_epi32(vacc${M}x0, vacc${M}x2)); 176 const __m128i vacc${M}x13 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x1, vacc${M}x3), _mm_unpackhi_epi32(vacc${M}x1, vacc${M}x3)); 177 178 $for M in range(MR): 179 __m128i vacc${M}x0123 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x02, vacc${M}x13), _mm_unpackhi_epi32(vacc${M}x02, vacc${M}x13)); 180 181 $for M in range(MR): 182 __m128 vscaled${M}x0123 = _mm_cvtepi32_ps(vacc${M}x0123); 183 184 $if DATATYPE == "QC8": 185 const __m128 vscale0123 = _mm_load_ps((const float*) w); 186 w = (const void*) ((const float*) w + 4); 187 $for M in range(MR): 188 vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale0123); 189 $else: 190 const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale); 191 $for M in range(MR): 192 vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale); 193 194 const __m128 voutput_max_less_zero_point = _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); 195 $for M in range(MR): 196 vscaled${M}x0123 = _mm_min_ps(vscaled${M}x0123, voutput_max_less_zero_point); 197 198 $for M in range(MR): 199 vacc${M}x0123 = _mm_cvtps_epi32(vscaled${M}x0123); 200 201 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point); 202 $for M in range(0, MR, 2): 203 __m128i vacc${M}${min(M+1, MR-1)}x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point); 204 205 $if DATATYPE == "QU8": 206 $if MR > 2: 207 __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123); 208 $else: 209 __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123); 210 211 vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min)); 212 $else: 213 $if SSE < 4: 214 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); 215 $for M in range(0, MR, 2): 216 vacc${M}${min(M+1, MR-1)}x0123 = _mm_max_epi16(vacc${M}${min(M+1, MR-1)}x0123, voutput_min); 217 218 $if MR > 2: 219 __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123); 220 $else: 221 __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123); 222 223 $if SSE == 4: 224 vout = _mm_max_epi8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min)); 225 226 if (nc >= 4) { 227 $for M in reversed(range(1, MR)): 228 $if SSE == 4: 229 *((uint32_t*) c${M}) = (uint32_t) _mm_extract_epi32(vout, ${M}); 230 $else: 231 *((uint32_t*) c${M}) = (uint32_t) _mm_cvtsi128_si32(_mm_shuffle_epi32(vout, _MM_SHUFFLE(${M}, ${M}, ${M}, ${M}))); 232 c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); 233 *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); 234 c0 = (${XINT8_T}*) ((uintptr_t) c0 + cn_stride); 235 236 a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks); 237 238 nc -= 4; 239 } else { 240 if (nc & 2) { 241 $for M in reversed(range(MR)): 242 *((uint16_t*) c${M}) = (uint16_t) _mm_extract_epi16(vout, ${M * 2}); 243 c${M} += 2; 244 vout = _mm_srli_epi32(vout, 16); 245 } 246 if (nc & 1) { 247 $if SSE == 4: 248 $for M in reversed(range(MR)): 249 *c${M} = (${XINT8_T}) _mm_extract_epi8(vout, ${M * 4}); 250 $else: 251 $for M in reversed(range(1, MR)): 252 *c${M} = (${XINT8_T}) _mm_extract_epi16(vout, ${M * 2}); 253 *c0 = (${XINT8_T}) _mm_cvtsi128_si32(vout); 254 } 255 256 nc = 0; 257 } 258 } while (nc != 0); 259} 260