// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert SSE in [2, 4]
$assert not AVX or SSE == 4
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/vcvt.h>


$ISA = "avx" if AVX else {2: "sse2", 4: "sse41"}[SSE]
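// Converts IEEE half-precision (binary16) values to single precision using
// 16-bit integer SIMD operations, with separate paths for normalized and
// denormalized inputs.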
void xnn_f16_f32_vcvt_ukernel__${ISA}_int16_x${BATCH_TILE}(
    size_t n,
    const void* input,
    float* output,
    const union xnn_f16_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

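  // Conversion constants: the sign mask isolates the sign bit, the exponent
  // offset and scale rebias normalized values, the magic mask and bias
  // reconstruct denormalized values, and the cutoff selects between the two paths.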
  const __m128i vsign_mask = _mm_load_si128((const __m128i*) params->sse_int16.sign_mask);
  const __m128i vexp_offset = _mm_load_si128((const __m128i*) params->sse_int16.exp_offset);
  const __m128 vexp_scale = _mm_load_ps(params->sse_int16.exp_scale);
  const __m128i vmagic_mask = _mm_load_si128((const __m128i*) params->sse_int16.magic_mask);
  const __m128 vmagic_bias = _mm_load_ps(params->sse_int16.magic_bias);
  const __m128i vdenorm_cutoff = _mm_load_si128((const __m128i*) params->sse_int16.denorm_cutoff);

  const uint16_t* i = (const uint16_t*) input;
  $if BATCH_TILE > 8:
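    // Main loop: convert ${BATCH_TILE} elements (${SIMD_TILE} vectors of 8 halves) per iteration.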
    for (; n >= ${BATCH_TILE} * sizeof(uint16_t); n -= ${BATCH_TILE} * sizeof(uint16_t)) {
      const __m128i vh0 = _mm_loadu_si128((const __m128i*) i);
      $for N in range(1, SIMD_TILE):
        const __m128i vh${N} = _mm_loadu_si128((const __m128i*) (i + ${N * 8}));
      i += ${BATCH_TILE};

      $for N in range(SIMD_TILE):
        const __m128i vsign${N} = _mm_and_si128(vh${N}, vsign_mask);

      $for N in range(SIMD_TILE):
        const __m128i vnonsign${N} = _mm_xor_si128(vh${N}, vsign${N});

      $for N in range(SIMD_TILE):
        const __m128i vprenorm${2*N} = _mm_slli_epi16(vnonsign${N}, 13);
        const __m128i vprenorm${2*N+1} = _mm_add_epi16(_mm_srli_epi16(vnonsign${N}, 3), vexp_offset);

      $for N in range(SIMD_TILE):
        const __m128i vnorm${2*N} = _mm_castps_si128(_mm_mul_ps(_mm_castsi128_ps(_mm_unpacklo_epi16(vprenorm${2*N}, vprenorm${2*N+1})), vexp_scale));
        const __m128i vnorm${2*N+1} = _mm_castps_si128(_mm_mul_ps(_mm_castsi128_ps(_mm_unpackhi_epi16(vprenorm${2*N}, vprenorm${2*N+1})), vexp_scale));

      $for N in range(SIMD_TILE):
        const __m128i vdenorm${2*N} = _mm_castps_si128(_mm_sub_ps(_mm_castsi128_ps(_mm_unpacklo_epi16(vnonsign${N}, vmagic_mask)), vmagic_bias));
        const __m128i vdenorm${2*N+1} = _mm_castps_si128(_mm_sub_ps(_mm_castsi128_ps(_mm_unpackhi_epi16(vnonsign${N}, vmagic_mask)), vmagic_bias));

      $for N in range(SIMD_TILE):
        const __m128i vmask${N} = _mm_cmpgt_epi16(vnonsign${N}, vdenorm_cutoff);

      $for N in range(SIMD_TILE):
        $if SSE == 4:
          const __m128i vf${2*N} = _mm_or_si128(_mm_unpacklo_epi16(_mm_setzero_si128(), vsign${N}),
            _mm_blendv_epi8(vdenorm${2*N}, vnorm${2*N}, _mm_cvtepi16_epi32(vmask${N})));
          const __m128i vf${2*N+1} = _mm_or_si128(_mm_unpackhi_epi16(_mm_setzero_si128(), vsign${N}),
            _mm_blendv_epi8(vdenorm${2*N+1}, vnorm${2*N+1}, _mm_unpackhi_epi16(vmask${N}, vmask${N})));
        $else:
          const __m128i vxmask${2*N} = _mm_unpacklo_epi16(vmask${N}, vmask${N});
          const __m128i vf${2*N} = _mm_or_si128(_mm_unpacklo_epi16(_mm_setzero_si128(), vsign${N}),
            _mm_or_si128(_mm_and_si128(vxmask${2*N}, vnorm${2*N}), _mm_andnot_si128(vxmask${2*N}, vdenorm${2*N})));
          const __m128i vxmask${2*N+1} = _mm_unpackhi_epi16(vmask${N}, vmask${N});
          const __m128i vf${2*N+1} = _mm_or_si128(_mm_unpackhi_epi16(_mm_setzero_si128(), vsign${N}),
            _mm_or_si128(_mm_and_si128(vxmask${2*N+1}, vnorm${2*N+1}), _mm_andnot_si128(vxmask${2*N+1}, vdenorm${2*N+1})));

      _mm_storeu_ps(output, _mm_castsi128_ps(vf0));
      $for N in range(1, 2*SIMD_TILE):
        _mm_storeu_ps(output + ${N * 4}, _mm_castsi128_ps(vf${N}));
      output += ${BATCH_TILE};
    }
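  // Convert 8 elements per iteration.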
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    const __m128i vh = _mm_loadu_si128((const __m128i*) i);
    i += 8;

    const __m128i vsign = _mm_and_si128(vh, vsign_mask);

    const __m128i vnonsign = _mm_xor_si128(vh, vsign);

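    // Normalized inputs: move the exponent and mantissa into binary32 position
    // (low/high prenorm halves), then rebias the exponent by multiplying by the
    // exponent scale.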
    const __m128i vprenorm_lo = _mm_slli_epi16(vnonsign, 13);
    const __m128i vprenorm_hi = _mm_add_epi16(_mm_srli_epi16(vnonsign, 3), vexp_offset);

    const __m128i vnorm_lo = _mm_castps_si128(_mm_mul_ps(_mm_castsi128_ps(_mm_unpacklo_epi16(vprenorm_lo, vprenorm_hi)), vexp_scale));
    const __m128i vnorm_hi = _mm_castps_si128(_mm_mul_ps(_mm_castsi128_ps(_mm_unpackhi_epi16(vprenorm_lo, vprenorm_hi)), vexp_scale));

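    // Denormalized inputs: pair each halfword with the magic mask to form a
    // binary32 number, then subtract the magic bias to recover the value.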
    const __m128i vdenorm_lo = _mm_castps_si128(_mm_sub_ps(_mm_castsi128_ps(_mm_unpacklo_epi16(vnonsign, vmagic_mask)), vmagic_bias));
    const __m128i vdenorm_hi = _mm_castps_si128(_mm_sub_ps(_mm_castsi128_ps(_mm_unpackhi_epi16(vnonsign, vmagic_mask)), vmagic_bias));

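    // Lanes above the denormal cutoff take the normalized result.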
    const __m128i vmask = _mm_cmpgt_epi16(vnonsign, vdenorm_cutoff);

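    // Select the normalized or denormalized result per 32-bit lane and OR the
    // sign back into the top bit.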
    $if SSE == 4:
      const __m128i vf_lo = _mm_or_si128(_mm_unpacklo_epi16(_mm_setzero_si128(), vsign),
        _mm_blendv_epi8(vdenorm_lo, vnorm_lo, _mm_cvtepi16_epi32(vmask)));
    $else:
      const __m128i vxmask_lo = _mm_unpacklo_epi16(vmask, vmask);
      const __m128i vf_lo = _mm_or_si128(_mm_unpacklo_epi16(_mm_setzero_si128(), vsign),
        _mm_or_si128(_mm_and_si128(vxmask_lo, vnorm_lo), _mm_andnot_si128(vxmask_lo, vdenorm_lo)));

    $if SSE == 4:
      const __m128i vf_hi = _mm_or_si128(_mm_unpackhi_epi16(_mm_setzero_si128(), vsign),
        _mm_blendv_epi8(vdenorm_hi, vnorm_hi, _mm_unpackhi_epi16(vmask, vmask)));
    $else:
      const __m128i vxmask_hi = _mm_unpackhi_epi16(vmask, vmask);
      const __m128i vf_hi = _mm_or_si128(_mm_unpackhi_epi16(_mm_setzero_si128(), vsign),
        _mm_or_si128(_mm_and_si128(vxmask_hi, vnorm_hi), _mm_andnot_si128(vxmask_hi, vdenorm_hi)));

    _mm_storeu_ps(output, _mm_castsi128_ps(vf_lo));
    _mm_storeu_ps(output + 4, _mm_castsi128_ps(vf_hi));
    output += 8;
  }
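  // Handle the last 1-7 elements. A full vector is loaded (the kernel is
  // declared XNN_OOB_READS), and only the valid results are stored below.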
  if XNN_UNPREDICTABLE(n != 0) {
    const __m128i vh = _mm_loadu_si128((const __m128i*) i);

    const __m128i vsign = _mm_and_si128(vh, vsign_mask);

    const __m128i vnonsign = _mm_xor_si128(vh, vsign);

    const __m128i vprenorm_lo = _mm_slli_epi16(vnonsign, 13);
    const __m128i vprenorm_hi = _mm_add_epi16(_mm_srli_epi16(vnonsign, 3), vexp_offset);

    const __m128i vnorm_lo = _mm_castps_si128(_mm_mul_ps(_mm_castsi128_ps(_mm_unpacklo_epi16(vprenorm_lo, vprenorm_hi)), vexp_scale));
    const __m128i vnorm_hi = _mm_castps_si128(_mm_mul_ps(_mm_castsi128_ps(_mm_unpackhi_epi16(vprenorm_lo, vprenorm_hi)), vexp_scale));

    const __m128i vdenorm_lo = _mm_castps_si128(_mm_sub_ps(_mm_castsi128_ps(_mm_unpacklo_epi16(vnonsign, vmagic_mask)), vmagic_bias));
    const __m128i vdenorm_hi = _mm_castps_si128(_mm_sub_ps(_mm_castsi128_ps(_mm_unpackhi_epi16(vnonsign, vmagic_mask)), vmagic_bias));

    const __m128i vmask = _mm_cmpgt_epi16(vnonsign, vdenorm_cutoff);

    $if SSE == 4:
      __m128i vf = _mm_or_si128(_mm_unpacklo_epi16(_mm_setzero_si128(), vsign),
        _mm_blendv_epi8(vdenorm_lo, vnorm_lo, _mm_cvtepi16_epi32(vmask)));
    $else:
      const __m128i vxmask_lo = _mm_unpacklo_epi16(vmask, vmask);
      __m128i vf = _mm_or_si128(_mm_unpacklo_epi16(_mm_setzero_si128(), vsign),
        _mm_or_si128(_mm_and_si128(vxmask_lo, vnorm_lo), _mm_andnot_si128(vxmask_lo, vdenorm_lo)));

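    // Store the remaining outputs in chunks of 4, 2, and 1 elements.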
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storeu_ps(output, _mm_castsi128_ps(vf));
      output += 4;

      $if SSE == 4:
        vf = _mm_or_si128(_mm_unpackhi_epi16(_mm_setzero_si128(), vsign),
          _mm_blendv_epi8(vdenorm_hi, vnorm_hi, _mm_unpackhi_epi16(vmask, vmask)));
      $else:
        const __m128i vxmask_hi = _mm_unpackhi_epi16(vmask, vmask);
        vf = _mm_or_si128(_mm_unpackhi_epi16(_mm_setzero_si128(), vsign),
          _mm_or_si128(_mm_and_si128(vxmask_hi, vnorm_hi), _mm_andnot_si128(vxmask_hi, vdenorm_hi)));
    }
    if (n & (2 * sizeof(uint16_t))) {
      _mm_storel_pi((__m64*) output, _mm_castsi128_ps(vf));
      output += 2;

      vf = _mm_castps_si128(_mm_movehl_ps(_mm_castsi128_ps(vf), _mm_castsi128_ps(vf)));
    }
    if (n & (1 * sizeof(uint16_t))) {
      _mm_store_ss(output, _mm_castsi128_ps(vf));
    }
  }
}