1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert DATATYPE in ["QS8", "QU8"] 7$assert REQUANTIZATION == "FP32" 8$assert SSE in [2, 4] 9$assert not AVX or SSE == 4 10$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE] 11$assert BATCH_TILE % 8 == 0 12$assert BATCH_TILE >= 8 13$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 14#include <assert.h> 15 16#include <${SSE_HEADER}> 17 18#include <xnnpack/vmul.h> 19 20 21$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("sse4" if SSE == 4 and DATATYPE == "QS8" else "sse2") 22$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE] 23$_MM_CVTEPX8_EPI16 = {"QS8": "_mm_cvtepi8_epi16", "QU8": "_mm_cvtepu8_epi16"}[DATATYPE] 24$_MM_PACKXS_EPI16 = {"QS8": "_mm_packs_epi16", "QU8": "_mm_packus_epi16"}[DATATYPE] 25$_MM_MIN_EPX8 = {"QS8": "_mm_min_epi8", "QU8": "_mm_min_epu8"}[DATATYPE] 26$_MM_MAX_EPX8 = {"QS8": "_mm_max_epi8", "QU8": "_mm_max_epu8"}[DATATYPE] 27$ISA = "avx" if AVX else {2: "sse2", 4: "sse41"}[SSE] 28void xnn_${DATATYPE.lower()}_vmul_minmax_${REQUANTIZATION.lower()}_ukernel__${ISA}_mul16_ld64_x${BATCH_TILE}( 29 size_t n, 30 const ${XINT8_T}* input_a, 31 const ${XINT8_T}* input_b, 32 ${XINT8_T}* output, 33 const union xnn_${DATATYPE.lower()}_mul_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 34 35{ 36 const __m128i va_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.a_zero_point); 37 const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.b_zero_point); 38 const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale); 39 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point); 40 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); 41 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_max); 42 43 for (; n >= ${BATCH_TILE} * sizeof(${XINT8_T}); n -= ${BATCH_TILE} * sizeof(${XINT8_T})) { 44 $if SSE == 4: 45 const __m128i va${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_a)); 46 const __m128i vb${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_b)); 47 $for N in range(8, BATCH_TILE, 8): 48 const __m128i va${ABC[N:N+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (input_a + ${N}))); 49 const __m128i vb${ABC[N:N+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (input_b + ${N}))); 50 $else: 51 __m128i va${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_a); 52 __m128i vb${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_b); 53 $for N in range(8, BATCH_TILE, 8): 54 __m128i va${ABC[N:N+8]} = _mm_loadl_epi64((const __m128i*) (input_a + ${N})); 55 __m128i vb${ABC[N:N+8]} = _mm_loadl_epi64((const __m128i*) (input_b + ${N})); 56 input_a += ${BATCH_TILE}; 57 input_b += ${BATCH_TILE}; 58 59 $if SSE < 4: 60 $if DATATYPE == "QU8": 61 const __m128i vzero = _mm_setzero_si128(); 62 $for N in range(0, BATCH_TILE, 8): 63 va${ABC[N:N+8]} = _mm_unpacklo_epi8(va${ABC[N:N+8]}, vzero); 64 vb${ABC[N:N+8]} = _mm_unpacklo_epi8(vb${ABC[N:N+8]}, vzero); 65 $else: 66 $for N in range(0, BATCH_TILE, 8): 67 va${ABC[N:N+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(va${ABC[N:N+8]}, va${ABC[N:N+8]}), 8); 68 vb${ABC[N:N+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${ABC[N:N+8]}, vb${ABC[N:N+8]}), 8); 69 70 $for N in range(0, BATCH_TILE, 8): 71 const __m128i vxa${ABC[N:N+8]} = _mm_sub_epi16(va${ABC[N:N+8]}, va_zero_point); 72 const __m128i vxb${ABC[N:N+8]} = _mm_sub_epi16(vb${ABC[N:N+8]}, vb_zero_point); 73 74 $for N in range(0, BATCH_TILE, 8): 75 const __m128i vprod${ABC[N:N+8]}lo = _mm_mullo_epi16(vxa${ABC[N:N+8]}, vxb${ABC[N:N+8]}); 76 const __m128i vprod${ABC[N:N+8]}hi = _mm_mulhi_epi16(vxa${ABC[N:N+8]}, vxb${ABC[N:N+8]}); 77 78 $for N in range(0, BATCH_TILE, 8): 79 const __m128i vprod${ABC[N:N+4]} = _mm_unpacklo_epi16(vprod${ABC[N:N+8]}lo, vprod${ABC[N:N+8]}hi); 80 const __m128i vprod${ABC[N+4:N+8]} = _mm_unpackhi_epi16(vprod${ABC[N:N+8]}lo, vprod${ABC[N:N+8]}hi); 81 82 $for N in range(0, BATCH_TILE, 4): 83 __m128 vfpacc${ABC[N:N+4]} = _mm_cvtepi32_ps(vprod${ABC[N:N+4]}); 84 85 $for N in range(0, BATCH_TILE, 4): 86 vfpacc${ABC[N:N+4]} = _mm_mul_ps(vfpacc${ABC[N:N+4]}, vscale); 87 88 $for N in range(0, BATCH_TILE, 4): 89 const __m128i vacc${ABC[N:N+4]} = _mm_cvtps_epi32(vfpacc${ABC[N:N+4]}); 90 91 $for N in range(0, BATCH_TILE, 8): 92 __m128i vout${ABC[N:N+8]} = _mm_adds_epi16(_mm_packs_epi32(vacc${ABC[N:N+4]}, vacc${ABC[N+4:N+8]}), voutput_zero_point); 93 94 $if DATATYPE == "QS8" and SSE < 4: 95 $for N in range(0, BATCH_TILE, 8): 96 vout${ABC[N:N+8]} = _mm_max_epi16(vout${ABC[N:N+8]}, voutput_min); 97 98 $for N in range(0, BATCH_TILE, 8): 99 vout${ABC[N:N+8]} = _mm_min_epi16(vout${ABC[N:N+8]}, voutput_max); 100 101 $for N in range(0, BATCH_TILE, 16): 102 $if N + 8 < BATCH_TILE: 103 __m128i vout${ABC[N:N+16]} = ${_MM_PACKXS_EPI16}(vout${ABC[N:N+8]}, vout${ABC[N+8:N+16]}); 104 $else: 105 __m128i vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_PACKXS_EPI16}(vout${ABC[N:N+8]}, vout${ABC[N:N+8]}); 106 107 $if DATATYPE == "QU8" or SSE == 4: 108 $for N in range(0, BATCH_TILE, 16): 109 $if N + 8 < BATCH_TILE: 110 vout${ABC[N:N+16]} = ${_MM_MAX_EPX8}(vout${ABC[N:N+16]}, voutput_min); 111 $else: 112 vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_MAX_EPX8}(vout${ABC[N:N+8]}${ABC[N:N+8]}, voutput_min); 113 114 $for N in range(0, BATCH_TILE, 16): 115 $if N + 8 < BATCH_TILE: 116 vout${ABC[N:N+16]} = ${_MM_MIN_EPX8}(vout${ABC[N:N+16]}, voutput_max); 117 $else: 118 vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_MIN_EPX8}(vout${ABC[N:N+8]}${ABC[N:N+8]}, voutput_max); 119 120 $if BATCH_TILE >= 16: 121 _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]}); 122 $else: 123 _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]}); 124 $for N in range(16, BATCH_TILE, 16): 125 $if N + 8 < BATCH_TILE: 126 _mm_storeu_si128((__m128i*) (output + ${N}), vout${ABC[N:N+16]}); 127 $else: 128 _mm_storel_epi64((__m128i*) (output + ${N}), vout${ABC[N:N+8]}${ABC[N:N+8]}); 129 output += ${BATCH_TILE}; 130 } 131 if XNN_UNLIKELY(n != 0) { 132 ${"do " if BATCH_TILE > 8 else ""}{ 133 $if SSE == 4: 134 const __m128i va${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_a)); 135 const __m128i vb${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_b)); 136 $else: 137 __m128i va${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_a); 138 __m128i vb${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_b); 139 $if BATCH_TILE > 8: 140 input_a += 8; 141 input_b += 8; 142 143 $if SSE < 4: 144 $if DATATYPE == "QU8": 145 const __m128i vzero = _mm_setzero_si128(); 146 va${ABC[0:8]} = _mm_unpacklo_epi8(va${ABC[0:8]}, vzero); 147 vb${ABC[0:8]} = _mm_unpacklo_epi8(vb${ABC[0:8]}, vzero); 148 $else: 149 va${ABC[0:8]} = _mm_srai_epi16(_mm_unpacklo_epi8(va${ABC[0:8]}, va${ABC[0:8]}), 8); 150 vb${ABC[0:8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${ABC[0:8]}, vb${ABC[0:8]}), 8); 151 152 const __m128i vxa${ABC[0:8]} = _mm_sub_epi16(va${ABC[0:8]}, va_zero_point); 153 const __m128i vxb${ABC[0:8]} = _mm_sub_epi16(vb${ABC[0:8]}, vb_zero_point); 154 155 const __m128i vprod${ABC[0:8]}lo = _mm_mullo_epi16(vxa${ABC[0:8]}, vxb${ABC[0:8]}); 156 const __m128i vprod${ABC[0:8]}hi = _mm_mulhi_epi16(vxa${ABC[0:8]}, vxb${ABC[0:8]}); 157 158 const __m128i vprod${ABC[0:4]} = _mm_unpacklo_epi16(vprod${ABC[0:8]}lo, vprod${ABC[0:8]}hi); 159 const __m128i vprod${ABC[4:8]} = _mm_unpackhi_epi16(vprod${ABC[0:8]}lo, vprod${ABC[0:8]}hi); 160 161 __m128 vfpacc${ABC[0:4]} = _mm_cvtepi32_ps(vprod${ABC[0:4]}); 162 __m128 vfpacc${ABC[4:8]} = _mm_cvtepi32_ps(vprod${ABC[4:8]}); 163 164 vfpacc${ABC[0:4]} = _mm_mul_ps(vfpacc${ABC[0:4]}, vscale); 165 vfpacc${ABC[4:8]} = _mm_mul_ps(vfpacc${ABC[4:8]}, vscale); 166 167 const __m128i vacc${ABC[0:4]} = _mm_cvtps_epi32(vfpacc${ABC[0:4]}); 168 const __m128i vacc${ABC[4:8]} = _mm_cvtps_epi32(vfpacc${ABC[4:8]}); 169 170 __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(vacc${ABC[0:4]}, vacc${ABC[4:8]}), voutput_zero_point); 171 $if DATATYPE == "QS8" and SSE < 4: 172 vout${ABC[0:8]} = _mm_max_epi16(vout${ABC[0:8]}, voutput_min); 173 vout${ABC[0:8]} = _mm_min_epi16(vout${ABC[0:8]}, voutput_max); 174 175 __m128i vout${ABC[0:8]}${ABC[0:8]} = ${_MM_PACKXS_EPI16}(vout${ABC[0:8]}, vout${ABC[0:8]}); 176 $if DATATYPE == "QU8" or SSE == 4: 177 vout${ABC[0:8]}${ABC[0:8]} = ${_MM_MAX_EPX8}(vout${ABC[0:8]}${ABC[0:8]}, voutput_min); 178 vout${ABC[0:8]}${ABC[0:8]} = ${_MM_MIN_EPX8}(vout${ABC[0:8]}${ABC[0:8]}, voutput_max); 179 180 $if BATCH_TILE > 8: 181 if XNN_LIKELY(n >= (8 * sizeof(${XINT8_T}))) { 182 _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]}); 183 output += 8; 184 n -= 8 * sizeof(${XINT8_T}); 185 } else { 186 if (n & (4 * sizeof(${XINT8_T}))) { 187 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}); 188 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32); 189 output += 4; 190 } 191 if (n & (2 * sizeof(${XINT8_T}))) { 192 $if SSE == 4: 193 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0); 194 $else: 195 *((uint16_t*) output) = (uint16_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}); 196 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16); 197 output += 2; 198 } 199 if (n & (1 * sizeof(${XINT8_T}))) { 200 $if SSE == 4: 201 *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0); 202 $else: 203 *output = (int32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}); 204 } 205 n = 0; 206 } 207 $else: 208 if (n & (4 * sizeof(${XINT8_T}))) { 209 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}); 210 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32); 211 output += 4; 212 } 213 if (n & (2 * sizeof(${XINT8_T}))) { 214 $if SSE == 4: 215 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0); 216 $else: 217 *((uint16_t*) output) = (uint16_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}); 218 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16); 219 output += 2; 220 } 221 if (n & (1 * sizeof(${XINT8_T}))) { 222 $if SSE == 4: 223 *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0); 224 $else: 225 *output = (${XINT8_T}) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}); 226 } 227 }${" while (n != 0);" if BATCH_TILE > 8 else ""} 228 } 229} 230