• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert SSE in [2, 4]
7$assert DATATYPE in ["QS8", "QU8"]
8$assert BATCH_TILE % 8 == 0
9$assert BATCH_TILE >= 8
10$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
11$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
12#include <assert.h>
13
14#include <${SSE_HEADER}>
15
16#include <xnnpack/common.h>
17#include <xnnpack/vcvt.h>
18
19
20$ISA = {2: "sse2", 4: "sse41"}[SSE]
21$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
22$_MM_PACKXS_EPI16 = {"QS8": "_mm_packs_epi16", "QU8": "_mm_packus_epi16"}[DATATYPE]
23$_MM_MAX_EPX8 = {"QS8": "_mm_max_epi8", "QU8": "_mm_max_epu8"}[DATATYPE]
// Converts a buffer of floats to 8-bit integers (int8_t for QS8, uint8_t for
// QU8) using SSE intrinsics.
//
// Per element: multiply by `scale`, clamp from above against
// `output_max_less_zero_point`, convert to int32 with _mm_cvtps_epi32, pack to
// int16 with signed saturation, add `output_zero_point` with saturation, then
// clamp from below against `output_min` and pack to 8 bits with saturation
// (the lower clamp runs before or after the final pack depending on the
// datatype and SSE level, see below).
//
//   n      - size of the input in BYTES; must be non-zero and a multiple of
//            sizeof(float).
//   x      - input floats, read with unaligned loads. The remainder path
//            always loads whole 4-float vectors, so it may read past the last
//            valid element (presumably what XNN_OOB_READS annotates - confirm
//            against xnnpack/common.h).
//   y      - output buffer; exactly n / sizeof(float) elements are written.
//   params - conversion constants, read with aligned vector loads
//            (_mm_load_ps / _mm_load_si128).
24void xnn_f32_${DATATYPE.lower()}_vcvt_ukernel__${ISA}_x${BATCH_TILE}(
25    size_t n,
26    const float* x,
27    ${XINT8_T}* y,
28    const union xnn_f32_${DATATYPE.lower()}_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
29{
30  assert(n != 0);
31  assert(n % sizeof(float) == 0);
32  assert(x != NULL);
33  assert(y != NULL);
34
35  const __m128 vscale = _mm_load_ps(params->sse${SSE}.scale);
36  const __m128 voutput_max_less_zero_point = _mm_load_ps(params->sse${SSE}.output_max_less_zero_point);
37  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse${SSE}.output_zero_point);
38  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse${SSE}.output_min);
39
40  $if BATCH_TILE > 8:
    // Main loop: one full batch tile of floats per iteration. This loop is
    // only generated when the batch tile is larger than 8.
41    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
42      __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
43      $for N in range(4, BATCH_TILE, 4):
44        __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
45      x += ${BATCH_TILE};
46
47      $for N in range(0, BATCH_TILE, 4):
48        vx${ABC[N:N+4]} = _mm_mul_ps(vx${ABC[N:N+4]}, vscale);
49
50      $for N in range(0, BATCH_TILE, 4):
51        vx${ABC[N:N+4]} = _mm_min_ps(vx${ABC[N:N+4]}, voutput_max_less_zero_point);
52
53      $for N in range(0, BATCH_TILE, 4):
54        const __m128i vy${ABC[N:N+4]} = _mm_cvtps_epi32(vx${ABC[N:N+4]});
55
56      $for N in range(0, BATCH_TILE, 8):
57        __m128i vy${ABC[N:N+8]} = _mm_packs_epi32(vy${ABC[N:N+4]}, vy${ABC[N+4:N+8]});
58
59      $for N in range(0, BATCH_TILE, 8):
60        vy${ABC[N:N+8]} = _mm_adds_epi16(vy${ABC[N:N+8]}, voutput_zero_point);
61
62      $if DATATYPE == "QS8" and SSE < 4:
        // Signed output on SSE2: apply the lower clamp on 16-bit lanes before
        // the final pack, because _mm_max_epi8 requires SSE4.1.
63        $for N in range(0, BATCH_TILE, 8):
64          vy${ABC[N:N+8]} = _mm_max_epi16(vy${ABC[N:N+8]}, voutput_min);
65
      // Pack 16-bit lanes to 8 bits. A trailing group of only 8 elements is
      // packed with itself, so just its low 8 bytes are meaningful.
66      $for N in range(0, BATCH_TILE, 16):
67        $if N + 8 < BATCH_TILE:
68          __m128i vy${ABC[N:N+16]} = ${_MM_PACKXS_EPI16}(vy${ABC[N:N+8]}, vy${ABC[N+8:N+16]});
69        $else:
70          vy${ABC[N:N+8]} = ${_MM_PACKXS_EPI16}(vy${ABC[N:N+8]}, vy${ABC[N:N+8]});
71
72      $if DATATYPE == "QU8" or SSE == 4:
        // Lower clamp applied on 8-bit lanes after the final pack.
73        $for N in range(0, BATCH_TILE, 16):
74          $if N + 8 < BATCH_TILE:
75            vy${ABC[N:N+16]} = ${_MM_MAX_EPX8}(vy${ABC[N:N+16]}, voutput_min);
76          $else:
77            vy${ABC[N:N+8]} = ${_MM_MAX_EPX8}(vy${ABC[N:N+8]}, voutput_min);
78
      // Full groups of 16 outputs use a 16-byte store; a trailing group of 8
      // uses an 8-byte store.
79      _mm_storeu_si128((__m128i*) y, vy${ABC[0:16]});
80      $for N in range(16, BATCH_TILE, 16):
81        $if N + 8 < BATCH_TILE:
82          _mm_storeu_si128((__m128i*) (y + ${N}), vy${ABC[N:N+16]});
83        $else:
84          _mm_storel_epi64((__m128i*) (y + ${N}), vy${ABC[N:N+8]});
85      y += ${BATCH_TILE};
86    }
  // Process 8 floats per iteration with the same multiply / clamp / convert /
  // pack sequence, storing 8 output bytes.
87  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
88    __m128 vx_lo = _mm_loadu_ps(x);
89    __m128 vx_hi = _mm_loadu_ps(x + 4);
90    x += 8;
91
92    vx_lo = _mm_mul_ps(vx_lo, vscale);
93    vx_hi = _mm_mul_ps(vx_hi, vscale);
94
95    vx_lo = _mm_min_ps(vx_lo, voutput_max_less_zero_point);
96    vx_hi = _mm_min_ps(vx_hi, voutput_max_less_zero_point);
97
98    const __m128i vy_lo = _mm_cvtps_epi32(vx_lo);
99    const __m128i vy_hi = _mm_cvtps_epi32(vx_hi);
100
101    __m128i vy = _mm_packs_epi32(vy_lo, vy_hi);
102    vy = _mm_adds_epi16(vy, voutput_zero_point);
103    $if DATATYPE == "QS8" and SSE < 4:
104      vy = _mm_max_epi16(vy, voutput_min);
105    vy = ${_MM_PACKXS_EPI16}(vy, vy);
106    $if DATATYPE == "QU8" or SSE == 4:
107      vy = ${_MM_MAX_EPX8}(vy, voutput_min);
108
109    _mm_storel_epi64((__m128i*) y, vy);
110    y += 8;
111  }
  // Remainder: 1 to 7 floats are left at this point.
112  if XNN_UNLIKELY(n != 0) {
    // x_hi is x + 4 floats when at least 4 elements remain, otherwise it is
    // equal to x. Both loads are full 4-float vectors regardless of how many
    // elements are valid, so they may read beyond the end of the input.
113    __m128 vx_lo = _mm_loadu_ps(x);
114    const float* x_hi = (const float*) ((uintptr_t) x + (n & (4 * sizeof(float))));
115    __m128 vx_hi = _mm_loadu_ps(x_hi);
116
117    vx_lo = _mm_mul_ps(vx_lo, vscale);
118    vx_hi = _mm_mul_ps(vx_hi, vscale);
119
120    vx_lo = _mm_min_ps(vx_lo, voutput_max_less_zero_point);
121    vx_hi = _mm_min_ps(vx_hi, voutput_max_less_zero_point);
122
123    const __m128i vy_lo = _mm_cvtps_epi32(vx_lo);
124    const __m128i vy_hi = _mm_cvtps_epi32(vx_hi);
125
126    __m128i vy = _mm_packs_epi32(vy_lo, vy_hi);
127    vy = _mm_adds_epi16(vy, voutput_zero_point);
128    $if DATATYPE == "QS8" and SSE < 4:
129      vy = _mm_max_epi16(vy, voutput_min);
130    vy = ${_MM_PACKXS_EPI16}(vy, vy);
131    $if DATATYPE == "QU8" or SSE == 4:
132      vy = ${_MM_MAX_EPX8}(vy, voutput_min);
133
    // Write 4, then 2, then 1 output bytes according to the bits of the
    // remaining element count, shifting already-written bytes out of the low
    // lanes of vy after each partial store.
134    if (n & (4 * sizeof(float))) {
135      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vy);
136      y += 4;
137      vy = _mm_srli_epi64(vy, 32);
138    }
139    $if SSE == 4:
140      if (n & (2 * sizeof(float))) {
141        *((uint16_t*) y) = (uint16_t) _mm_extract_epi16(vy, 0);
142        y += 2;
143        vy = _mm_srli_epi32(vy, 16);
144      }
145      if (n & (1 * sizeof(float))) {
146        *y = (${XINT8_T}) _mm_extract_epi8(vy, 0);
147      }
148    $else:
149      {
        // Inner scope: this scalar vy_lo shadows the vector vy_lo declared
        // above and holds the low 32 bits of vy for the last 2-byte and
        // 1-byte stores.
150        uint32_t vy_lo = (uint32_t) _mm_cvtsi128_si32(vy);
151        if (n & (2 * sizeof(float))) {
152          *((uint16_t*) y) = (uint16_t) vy_lo;
153          y += 2;
154          vy_lo >>= 16;
155        }
156        if (n & (1 * sizeof(float))) {
157          *y = (${XINT8_T}) vy_lo;
158        }
159      }
160  }
161}
162