// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$assert RR_STEPS in [1, 2]
$assert DIV_ALGO in ["div", "nr2"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

void xnn_f32_sigmoid_ukernel__avx_rr${RR_STEPS}_p5_${DIV_ALGO}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const void* params)
{
  assert(n % sizeof(float) == 0);

  const __m256 vsign_mask = _mm256_set1_ps(-0.0f);
  const __m256 vmagic_bias = _mm256_set1_ps(0x1.8000FEp23f);
  const __m256 vlog2e = _mm256_set1_ps(0x1.715476p0f);
  $if RR_STEPS == 1:
    const __m256 vminus_ln2 = _mm256_set1_ps(-0x1.62E43p-1f);
  $else:
    const __m256 vminus_ln2_hi = _mm256_set1_ps(-0x1.62E43p-1f);
    const __m256 vminus_ln2_lo = _mm256_set1_ps(0x1.05C61p-29f);
  const __m256 vc5 = _mm256_set1_ps(0x1.0F9F9Cp-7f);
  const __m256 vc4 = _mm256_set1_ps(0x1.573A1Ap-5f);
  const __m256 vc3 = _mm256_set1_ps(0x1.555A80p-3f);
  const __m256 vc2 = _mm256_set1_ps(0x1.FFFDC6p-2f);
  const __m256 vc1 = _mm256_set1_ps(0x1.FFFFF6p-1f);
  const __m256 vone = _mm256_set1_ps(1.0f);
  $if DIV_ALGO == "nr2":
    const __m256 vtwo = _mm256_set1_ps(2.0f);
  const __m256 vdenorm_cutoff = _mm256_set1_ps(-0x1.5D589Ep+6f);

  $if BATCH_TILE > 8:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      const __m256 vx${ABC[0]} = _mm256_loadu_ps(x);
      $for N in range(1, SIMD_TILE):
        const __m256 vx${ABC[N]} = _mm256_loadu_ps(x + ${N * 8});
      x += ${BATCH_TILE};

      // z := -|x|, so the exponential below is only evaluated on the non-positive half-line.
      $for N in range(SIMD_TILE):
        const __m256 vz${ABC[N]} = _mm256_or_ps(vx${ABC[N]}, vsign_mask);

      // n := round(z / ln(2)), computed with the magic-bias rounding trick.
      $for N in range(SIMD_TILE):
        __m256 vn${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vz${ABC[N]}, vlog2e), vmagic_bias);

      // s := 2**n, reconstructed by shifting the integer bits of vn (still carrying the magic
      // bias) into the exponent field. AVX has no 256-bit integer shift, so the two 128-bit
      // halves are shifted separately with SSE2.
      $for N in range(SIMD_TILE):
        const __m128 vs_lo${ABC[N]} = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn${ABC[N]})), 23));
        const __m128 vs_hi${ABC[N]} = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn${ABC[N]}, 1)), 23));
        const __m256 vs${ABC[N]} = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo${ABC[N]}), vs_hi${ABC[N]}, 1);

      // Subtract the magic bias to recover n as a floating-point value.
      $for N in range(SIMD_TILE):
        vn${ABC[N]} = _mm256_sub_ps(vn${ABC[N]}, vmagic_bias);

      // t := z - n * ln(2), the reduced argument (one or two range-reduction steps).
      $if RR_STEPS == 1:
        $for N in range(SIMD_TILE):
          __m256 vt${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vn${ABC[N]}, vminus_ln2), vz${ABC[N]});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vt${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vn${ABC[N]}, vminus_ln2_hi), vz${ABC[N]});

        $for N in range(SIMD_TILE):
          vt${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vn${ABC[N]}, vminus_ln2_lo), vt${ABC[N]});

      // Evaluate p := c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))), so that exp(t) ~= 1 + t * p.
      $for N in range(SIMD_TILE):
        __m256 vp${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vc5, vt${ABC[N]}), vc4);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vp${ABC[N]}, vt${ABC[N]}), vc3);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vp${ABC[N]}, vt${ABC[N]}), vc2);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vp${ABC[N]}, vt${ABC[N]}), vc1);

      $for N in range(SIMD_TILE):
        vt${ABC[N]} = _mm256_mul_ps(vt${ABC[N]}, vs${ABC[N]});

      // e := exp(z) ~= s * (1 + t * p).
      $for N in range(SIMD_TILE):
        const __m256 ve${ABC[N]} = _mm256_add_ps(_mm256_mul_ps(vt${ABC[N]}, vp${ABC[N]}), vs${ABC[N]});

      // d := exp(z) + 1.
      $for N in range(SIMD_TILE):
        const __m256 vd${ABC[N]} = _mm256_add_ps(ve${ABC[N]}, vone);

      // f := exp(z) / (exp(z) + 1) = sigmoid(z), computed either with a full-precision division
      // or with a reciprocal estimate refined by two Newton-Raphson iterations.
      $if DIV_ALGO == "div":
        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_div_ps(ve${ABC[N]}, vd${ABC[N]});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vr${ABC[N]} = _mm256_rcp_ps(vd${ABC[N]});

        $for N in range(SIMD_TILE):
          vr${ABC[N]} = _mm256_mul_ps(vr${ABC[N]}, _mm256_sub_ps(vtwo, _mm256_mul_ps(vr${ABC[N]}, vd${ABC[N]})));
          vr${ABC[N]} = _mm256_mul_ps(vr${ABC[N]}, _mm256_sub_ps(vtwo, _mm256_mul_ps(vr${ABC[N]}, vd${ABC[N]})));

        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_mul_ps(ve${ABC[N]}, vr${ABC[N]});

      // For z below the denormal cutoff the result underflows; flush it to 0.
      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_andnot_ps(_mm256_cmp_ps(vz${ABC[N]}, vdenorm_cutoff, _CMP_LT_OS), vf${ABC[N]});

      // For positive x, sigmoid(x) = 1 - sigmoid(-x): select f or 1 - f based on the sign of x.
      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_blendv_ps(_mm256_sub_ps(vone, vf${ABC[N]}), vf${ABC[N]}, vx${ABC[N]});

      _mm256_storeu_ps(y, vf${ABC[0]});
      $for N in range(1, SIMD_TILE):
        _mm256_storeu_ps(y + ${N * 8}, vf${ABC[N]});
      y += ${BATCH_TILE};
    }
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    const __m256 vx = _mm256_loadu_ps(x);
    x += 8;

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    const __m128 vs_lo = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 23));
    const __m128 vs_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn, 1)), 23));
    const __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1);

    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if RR_STEPS == 1:
      __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz);
    $else:
      __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_hi), vz);
      vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_lo), vt);

    __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc5, vt), vc4);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc3);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc2);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc1);

    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_add_ps(_mm256_mul_ps(vt, vp), vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      __m256 vr = _mm256_rcp_ps(vd);
      vr = _mm256_mul_ps(vr, _mm256_sub_ps(vtwo, _mm256_mul_ps(vr, vd)));
      vr = _mm256_mul_ps(vr, _mm256_sub_ps(vtwo, _mm256_mul_ps(vr, vd)));
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    _mm256_storeu_ps(y, vf);
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n));

    const __m256 vx = _mm256_maskload_ps(x, vmask);

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);
    const __m128 vs_lo = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 23));
    const __m128 vs_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_extractf128_ps(vn, 1)), 23));
    const __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1);

    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if RR_STEPS == 1:
      __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz);
    $else:
      __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_hi), vz);
      vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2_lo), vt);

    __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc5, vt), vc4);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc3);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc2);
    vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc1);

    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_add_ps(_mm256_mul_ps(vt, vp), vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      __m256 vr = _mm256_rcp_ps(vd);
      vr = _mm256_mul_ps(vr, _mm256_sub_ps(vtwo, _mm256_mul_ps(vr, vd)));
      vr = _mm256_mul_ps(vr, _mm256_sub_ps(vtwo, _mm256_mul_ps(vr, vd)));
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    // _mm256_maskstore_ps(y, vmask, vf) could be used here, but triggers msan failures (probably an msan bug).
    __m128 vf_lo = _mm256_castps256_ps128(vf);
    if (n & (4 * sizeof(float))) {
      _mm_storeu_ps(y, vf_lo);
      vf_lo = _mm256_extractf128_ps(vf, 1);
      y += 4;
    }
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vf_lo);
      vf_lo = _mm_movehl_ps(vf_lo, vf_lo);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vf_lo);
    }
  }
}
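
// Usage sketch: the generated microkernel takes `n` in bytes (a multiple of sizeof(float)),
// and this template variant never reads `params`. The call below is only an illustration,
// assuming the rr2/div variant with BATCH_TILE = 8 has been generated; `apply_sigmoid` is a
// hypothetical wrapper, not part of XNNPACK.
//
//   void apply_sigmoid(const float* input, float* output, size_t batch_elements) {
//     xnn_f32_sigmoid_ukernel__avx_rr2_p5_div_x8(
//       batch_elements * sizeof(float), input, output, /*params=*/NULL);
//   }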