// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Template for f32 -> QS8/QU8 conversion micro-kernels using SSE2 or SSE4.1.
// Template parameters:
//   SSE        - instruction-set level: 2 (SSE2) or 4 (SSE4.1)
//   DATATYPE   - output type: "QS8" (signed int8) or "QU8" (unsigned uint8)
//   BATCH_TILE - elements per main-loop iteration (multiple of 8, >= 8)
$assert SSE in [2, 4]
$assert DATATYPE in ["QS8", "QU8"]
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/vcvt.h>


$ISA = {2: "sse2", 4: "sse41"}[SSE]
$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$_MM_PACKXS_EPI16 = {"QS8": "_mm_packs_epi16", "QU8": "_mm_packus_epi16"}[DATATYPE]
$_MM_MAX_EPX8 = {"QS8": "_mm_max_epi8", "QU8": "_mm_max_epu8"}[DATATYPE]
// Quantizes n bytes (n / sizeof(float) elements) of floats from x into 8-bit
// integers in y: y = clamp(round(x * scale) + zero_point).  Note n counts
// BYTES, not elements (see asserts below).  XNN_OOB_READS marks that the
// remainder path may read past the end of x (excess lanes are never stored).
void xnn_f32_${DATATYPE.lower()}_vcvt_ukernel__${ISA}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    ${XINT8_T}* y,
    const union xnn_f32_${DATATYPE.lower()}_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  // Aligned loads: _mm_load_ps/_mm_load_si128 require the params fields to be
  // 16-byte aligned (presumably guaranteed by the params init code elsewhere).
  const __m128 vscale = _mm_load_ps(params->sse${SSE}.scale);
  // Upper clamp is applied in the float domain, before the zero point is
  // added (hence "max less zero point").
  const __m128 voutput_max_less_zero_point = _mm_load_ps(params->sse${SSE}.output_max_less_zero_point);
  // Zero point as 8 x int16 (used with _mm_adds_epi16 below).
  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse${SSE}.output_zero_point);
  // Lower clamp: consumed as 16-bit lanes on the QS8/SSE2 path and as 8-bit
  // lanes on the QU8/SSE4.1 paths; lane layout must match in the params union.
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse${SSE}.output_min);

  $if BATCH_TILE > 8:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
      x += ${BATCH_TILE};

      // Scale in float.
      $for N in range(0, BATCH_TILE, 4):
        vx${ABC[N:N+4]} = _mm_mul_ps(vx${ABC[N:N+4]}, vscale);

      // Clamp from above while still in float.
      $for N in range(0, BATCH_TILE, 4):
        vx${ABC[N:N+4]} = _mm_min_ps(vx${ABC[N:N+4]}, voutput_max_less_zero_point);

      // Convert float -> int32.
      $for N in range(0, BATCH_TILE, 4):
        const __m128i vy${ABC[N:N+4]} = _mm_cvtps_epi32(vx${ABC[N:N+4]});

      // Narrow int32 -> int16 with signed saturation, then add the zero point
      // with a saturating 16-bit addition.
      $for N in range(0, BATCH_TILE, 8):
        __m128i vy${ABC[N:N+8]} = _mm_packs_epi32(vy${ABC[N:N+4]}, vy${ABC[N+4:N+8]});

      $for N in range(0, BATCH_TILE, 8):
        vy${ABC[N:N+8]} = _mm_adds_epi16(vy${ABC[N:N+8]}, voutput_zero_point);

      $if DATATYPE == "QS8" and SSE < 4:
        // SSE2 has no _mm_max_epi8, so apply the lower clamp in the 16-bit
        // domain before narrowing to int8.
        $for N in range(0, BATCH_TILE, 8):
          vy${ABC[N:N+8]} = _mm_max_epi16(vy${ABC[N:N+8]}, voutput_min);

      // Narrow int16 -> 8-bit with saturation (signed for QS8, unsigned for
      // QU8).  A trailing odd 8-element group packs against itself.
      $for N in range(0, BATCH_TILE, 16):
        $if N + 8 < BATCH_TILE:
          __m128i vy${ABC[N:N+16]} = ${_MM_PACKXS_EPI16}(vy${ABC[N:N+8]}, vy${ABC[N+8:N+16]});
        $else:
          vy${ABC[N:N+8]} = ${_MM_PACKXS_EPI16}(vy${ABC[N:N+8]}, vy${ABC[N:N+8]});

      $if DATATYPE == "QU8" or SSE == 4:
        // Lower clamp applied directly on 8-bit lanes.
        $for N in range(0, BATCH_TILE, 16):
          $if N + 8 < BATCH_TILE:
            vy${ABC[N:N+16]} = ${_MM_MAX_EPX8}(vy${ABC[N:N+16]}, voutput_min);
          $else:
            vy${ABC[N:N+8]} = ${_MM_MAX_EPX8}(vy${ABC[N:N+8]}, voutput_min);

      _mm_storeu_si128((__m128i*) y, vy${ABC[0:16]});
      $for N in range(16, BATCH_TILE, 16):
        $if N + 8 < BATCH_TILE:
          _mm_storeu_si128((__m128i*) (y + ${N}), vy${ABC[N:N+16]});
        $else:
          _mm_storel_epi64((__m128i*) (y + ${N}), vy${ABC[N:N+8]});
      y += ${BATCH_TILE};
    }
  // 8-element loop (also the main loop when BATCH_TILE == 8): same
  // scale -> clamp -> convert -> pack pipeline as above.
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    __m128 vx_lo = _mm_loadu_ps(x);
    __m128 vx_hi = _mm_loadu_ps(x + 4);
    x += 8;

    vx_lo = _mm_mul_ps(vx_lo, vscale);
    vx_hi = _mm_mul_ps(vx_hi, vscale);

    vx_lo = _mm_min_ps(vx_lo, voutput_max_less_zero_point);
    vx_hi = _mm_min_ps(vx_hi, voutput_max_less_zero_point);

    const __m128i vy_lo = _mm_cvtps_epi32(vx_lo);
    const __m128i vy_hi = _mm_cvtps_epi32(vx_hi);

    __m128i vy = _mm_packs_epi32(vy_lo, vy_hi);
    vy = _mm_adds_epi16(vy, voutput_zero_point);
    $if DATATYPE == "QS8" and SSE < 4:
      vy = _mm_max_epi16(vy, voutput_min);
    vy = ${_MM_PACKXS_EPI16}(vy, vy);
    $if DATATYPE == "QU8" or SSE == 4:
      vy = ${_MM_MAX_EPX8}(vy, voutput_min);

    _mm_storel_epi64((__m128i*) y, vy);
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    // Remainder of 1-7 elements.  Both 4-float loads may read past the valid
    // data (allowed: the kernel is declared XNN_OOB_READS); out-of-range
    // lanes are computed but never stored.
    __m128 vx_lo = _mm_loadu_ps(x);
    // If at least 4 elements remain, load the high quad from x + 4; otherwise
    // re-load the low quad (n & 16 yields the 16-byte offset or zero).
    const float* x_hi = (const float*) ((uintptr_t) x + (n & (4 * sizeof(float))));
    __m128 vx_hi = _mm_loadu_ps(x_hi);

    vx_lo = _mm_mul_ps(vx_lo, vscale);
    vx_hi = _mm_mul_ps(vx_hi, vscale);

    vx_lo = _mm_min_ps(vx_lo, voutput_max_less_zero_point);
    vx_hi = _mm_min_ps(vx_hi, voutput_max_less_zero_point);

    const __m128i vy_lo = _mm_cvtps_epi32(vx_lo);
    const __m128i vy_hi = _mm_cvtps_epi32(vx_hi);

    __m128i vy = _mm_packs_epi32(vy_lo, vy_hi);
    vy = _mm_adds_epi16(vy, voutput_zero_point);
    $if DATATYPE == "QS8" and SSE < 4:
      vy = _mm_max_epi16(vy, voutput_min);
    vy = ${_MM_PACKXS_EPI16}(vy, vy);
    $if DATATYPE == "QU8" or SSE == 4:
      vy = ${_MM_MAX_EPX8}(vy, voutput_min);

    // Store the valid low 4 / 2 / 1 bytes, shifting consumed bytes out of vy
    // between steps.
    if (n & (4 * sizeof(float))) {
      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vy);
      y += 4;
      vy = _mm_srli_epi64(vy, 32);
    }
    $if SSE == 4:
      if (n & (2 * sizeof(float))) {
        *((uint16_t*) y) = (uint16_t) _mm_extract_epi16(vy, 0);
        y += 2;
        vy = _mm_srli_epi32(vy, 16);
      }
      if (n & (1 * sizeof(float))) {
        *y = (${XINT8_T}) _mm_extract_epi8(vy, 0);
      }
    $else:
      {
        // SSE2 lacks _mm_extract_epi8: move the low dword to a scalar
        // register and shift there instead.
        uint32_t vy_lo = (uint32_t) _mm_cvtsi128_si32(vy);
        if (n & (2 * sizeof(float))) {
          *((uint16_t*) y) = (uint16_t) vy_lo;
          y += 2;
          vy_lo >>= 16;
        }
        if (n & (1 * sizeof(float))) {
          *y = (${XINT8_T}) vy_lo;
        }
      }
  }
}