// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/vunary.h>


void xnn_f16_velu_ukernel__avx2_rr1_p3_x${BATCH_TILE}(
    size_t n,
    const void* input,
    void* output,
    const union xnn_f16_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n % sizeof(uint16_t) == 0);

  const __m256 vprescale = _mm256_load_ps(params->avx2_rr1_p3.prescale);
  const __m256 vsat_cutoff = _mm256_load_ps(params->avx2_rr1_p3.sat_cutoff);
  const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p3.magic_bias);
  const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p3.log2e);
  const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p3.minus_ln2);
  const __m256 vc3 = _mm256_load_ps(params->avx2_rr1_p3.c3);
  const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p3.c2);
  const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p3.c1);
  const __m256 valpha = _mm256_load_ps(params->avx2_rr1_p3.alpha);
  const __m256 vbeta = _mm256_load_ps(params->avx2_rr1_p3.beta);

  const uint16_t* i = (const uint16_t*) input;
  uint16_t* o = (uint16_t*) output;
  $if BATCH_TILE > 8:
    for (; n >= ${BATCH_TILE} * sizeof(uint16_t); n -= ${BATCH_TILE} * sizeof(uint16_t)) {
      __m256 vx0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
      $for N in range(1, SIMD_TILE):
        __m256 vx${N} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + ${N * 8})));
      i += ${BATCH_TILE};

      $for N in range(SIMD_TILE):
        const __m256 vz${N} = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx${N}, vprescale));

      $for N in range(SIMD_TILE):
        __m256 vn${N} = _mm256_fmadd_ps(vz${N}, vlog2e, vmagic_bias);

      $for N in range(SIMD_TILE):
        __m256 vs${N} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${N}), 23));
        vn${N} = _mm256_sub_ps(vn${N}, vmagic_bias);

      $for N in range(SIMD_TILE):
        __m256 vt${N} = _mm256_fmadd_ps(vn${N}, vminus_ln2, vz${N});

      $for N in range(SIMD_TILE):
        __m256 vp${N} = _mm256_fmadd_ps(vc3, vt${N}, vc2);

      $for N in range(SIMD_TILE):
        vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vc1);
        vt${N} = _mm256_mul_ps(vt${N}, valpha);

      $for N in range(SIMD_TILE):
        vt${N} = _mm256_mul_ps(vt${N}, vs${N});
        vs${N} = _mm256_fmsub_ps(vs${N}, valpha, valpha);

      $for N in range(SIMD_TILE):
        const __m256 ve${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vs${N});
        vx${N} = _mm256_mul_ps(vx${N}, vbeta);

      $for N in range(SIMD_TILE):
        const __m256 vy${N} = _mm256_blendv_ps(vx${N}, ve${N}, vx${N});

      _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vy0, _MM_FROUND_NO_EXC));
      $for N in range(1, SIMD_TILE):
        _mm_storeu_si128((__m128i*) (o + ${N * 8}), _mm256_cvtps_ph(vy${N}, _MM_FROUND_NO_EXC));
      o += ${BATCH_TILE};
    }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
    i += 8;

    const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);
    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

    __m256 vp = _mm256_fmadd_ps(vc3, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);
    vt = _mm256_mul_ps(vt, valpha);
    vt = _mm256_mul_ps(vt, vs);
    vs = _mm256_fmsub_ps(vs, valpha, valpha);
    const __m256 ve = _mm256_fmadd_ps(vp, vt, vs);
    vx = _mm256_mul_ps(vx, vbeta);
    const __m256 vy = _mm256_blendv_ps(vx, ve, vx);

    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC));
    o += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(uint16_t));
    assert(n <= 7 * sizeof(uint16_t));
    __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));

    const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);
    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

    __m256 vp = _mm256_fmadd_ps(vc3, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);
    vt = _mm256_mul_ps(vt, valpha);
    vt = _mm256_mul_ps(vt, vs);
    vs = _mm256_fmsub_ps(vs, valpha, valpha);
    const __m256 ve = _mm256_fmadd_ps(vp, vt, vs);
    vx = _mm256_mul_ps(vx, vbeta);
    const __m256 vy = _mm256_blendv_ps(vx, ve, vx);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC);
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) o, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      o += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      _mm_storeu_si32(o, vh);
      vh = _mm_srli_epi64(vh, 32);
      o += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      *o = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}
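
// Note: the block below is an illustrative, non-generated reference sketch and is
// compiled out. Per element, the kernel above computes
//   z = max(sat_cutoff, x * prescale)
//   y = x >= 0 ? beta * x : alpha * (exp(z) - 1)
// where exp(z) = 2**n * exp(t), n = round(z * log2(e)) is recovered with the
// magic-bias trick (the left shift by 23 moves the rounded integer into the float
// exponent field, yielding 2**n), t = z - n * ln(2) is the reduced argument, and
// exp(t) is approximated as 1 + t * p(t) with p(t) = c1 + c2*t + c3*t*t (a
// degree-3 approximation overall, hence the "p3" in the kernel name). The sketch
// assumes the fp16 library's fp16_ieee_to_fp32_value()/fp16_ieee_from_fp32_value()
// converters (not included by this template) and omits the sat_cutoff clamp, which
// only guards the vector approximation against underflow.
#if 0
#include <math.h>
#include <fp16.h>

static uint16_t xnn_f16_elu_scalar_reference(uint16_t h, float prescale, float alpha, float beta) {
  const float x = fp16_ieee_to_fp32_value(h);
  // Positive inputs are scaled by beta; negative inputs go through alpha * (exp(prescale * x) - 1).
  const float y = x >= 0.0f ? x * beta : alpha * expm1f(x * prescale);
  return fp16_ieee_from_fp32_value(y);
}
#endif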