// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


$ISA = {2: "sse2", 4: "sse41"}[SSE]
// Element-wise f32 sigmoid over a contiguous array, processed in 128-bit
// (4-float) vectors.
//
//   n      - size of the input in BYTES; must be a multiple of sizeof(float)
//            (asserted below).
//   x      - input floats; read with unaligned loads.
//   y      - output floats; written with unaligned stores.
//   params - not referenced by this kernel.
//
// Per lane the kernel:
//   1. forms z = x with the sign bit forced on, i.e. -|x|;
//   2. builds e from z using a power-of-two scale s and a degree-5 polynomial
//      in a reduced argument t (e is presumably an approximation of exp(z),
//      judging by the log2e / ln2 constant names and the "p5" in the kernel
//      name - confirm);
//   3. computes f = e / (e + 1);
//   4. forces f to zero in lanes where z is below vdenorm_cutoff;
//   5. keeps f for lanes whose x has the sign bit set and uses 1 - f otherwise.
//
// NOTE(review): the remainder path loads a full 4-float vector even when fewer
// than 4 elements are left, so it reads past the n input bytes. Presumably
// callers guarantee readable padding and XNN_DISABLE_TSAN is related to this -
// confirm against the callers.
// NOTE(review): BATCH_TILE and SSE are not defined in this file; they appear
// to be supplied by the code generator that expands this template - confirm.
void xnn_f32_sigmoid_ukernel__${ISA}_p5_div_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const void* params) XNN_DISABLE_TSAN
{
  assert(n % sizeof(float) == 0);

  // -0.0f has only the sign bit set; OR-ing it into x yields -|x|.
  const __m128 vsign_mask = _mm_set1_ps(-0.0f);
  // Added to z * log2e and subtracted again below; in between, the bit pattern
  // of the sum is shifted left by 23 to build the scale s.
  // NOTE(review): presumably chosen so the sum's low mantissa bits hold the
  // rounded integer plus the float exponent bias - confirm.
  const __m128 vmagic_bias = _mm_set1_ps(0x1.8000FEp23f);
  const __m128 vlog2e = _mm_set1_ps(0x1.715476p0f);
  // -ln2 split into a high and a low part; both are applied to n below.
  const __m128 vminus_ln2_hi = _mm_set1_ps(-0x1.62E400p-1f);
  const __m128 vminus_ln2_lo = _mm_set1_ps(-0x1.7F7D1Cp-20f);
  // Polynomial coefficients, consumed from vc5 down to vc1 in Horner form.
  const __m128 vc5 = _mm_set1_ps(0x1.0F9F9Cp-7f);
  const __m128 vc4 = _mm_set1_ps(0x1.573A1Ap-5f);
  const __m128 vc3 = _mm_set1_ps(0x1.555A80p-3f);
  const __m128 vc2 = _mm_set1_ps(0x1.FFFDC6p-2f);
  const __m128 vc1 = _mm_set1_ps(0x1.FFFFF6p-1f);
  const __m128 vone = _mm_set1_ps(1.0f);
  // Lanes with z below this value get an output of zero before the final
  // sign-based selection.
  const __m128 vdenorm_cutoff = _mm_set1_ps(-0x1.5D589Ep+6f);

  $if BATCH_TILE > 4:
    // Unrolled main loop: BATCH_TILE floats per iteration. Each stage of the
    // computation is emitted for every 4-float group before the next stage
    // starts.
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      const __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        const __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});

      // z = -|x|
      $for N in range(0, BATCH_TILE, 4):
        const __m128 vz${ABC[N:N+4]} = _mm_or_ps(vx${ABC[N:N+4]}, vsign_mask);

      // n = z * log2e + magic_bias
      $for N in range(0, BATCH_TILE, 4):
        __m128 vn${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vz${ABC[N:N+4]}, vlog2e), vmagic_bias);

      // s = bit pattern of n shifted left by 23 (the width of the f32 mantissa
      // field), reinterpreted as float.
      $for N in range(0, BATCH_TILE, 4):
        const __m128 vs${ABC[N:N+4]} = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn${ABC[N:N+4]}), 23));

      // Remove the bias again so that n can be used in the reduction below.
      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = _mm_sub_ps(vn${ABC[N:N+4]}, vmagic_bias);

      // t = z - n * ln2, applied in two steps (high part, then low part).
      $for N in range(0, BATCH_TILE, 4):
        __m128 vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_hi), vz${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_lo), vt${ABC[N:N+4]});

      // p = c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))), one Horner step per
      // stage.
      $for N in range(0, BATCH_TILE, 4):
        __m128 vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vc5, vt${ABC[N:N+4]}), vc4);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc3);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc2);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc1);

      // e = s + (t * s) * p
      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_mul_ps(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        __m128 ve${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vt${ABC[N:N+4]}, vp${ABC[N:N+4]}), vs${ABC[N:N+4]});

      // f = e / (e + 1)
      $for N in range(0, BATCH_TILE, 4):
        __m128 vd${ABC[N:N+4]} = _mm_add_ps(ve${ABC[N:N+4]}, vone);

      $for N in range(0, BATCH_TILE, 4):
        __m128 vf${ABC[N:N+4]} = _mm_div_ps(ve${ABC[N:N+4]}, vd${ABC[N:N+4]});

      // Zero f in lanes where z < vdenorm_cutoff (the comparison mask is
      // inverted and AND-ed with f).
      $for N in range(0, BATCH_TILE, 4):
        vf${ABC[N:N+4]} = _mm_andnot_ps(_mm_cmplt_ps(vz${ABC[N:N+4]}, vdenorm_cutoff), vf${ABC[N:N+4]});

      // Final selection on the sign bit of x: lanes with the sign bit set keep
      // f, the others take 1 - f. The first branch uses a sign-bit-driven
      // blend; the second builds an explicit all-ones/all-zeros mask with a
      // signed integer compare against zero and combines with and/andnot/or.
      $if SSE >= 4:
        $for N in range(0, BATCH_TILE, 4):
          vf${ABC[N:N+4]} = _mm_blendv_ps(_mm_sub_ps(vone, vf${ABC[N:N+4]}), vf${ABC[N:N+4]}, vx${ABC[N:N+4]});
      $else:
        $for N in range(0, BATCH_TILE, 4):
          const __m128 vm${ABC[N:N+4]} = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx${ABC[N:N+4]})));

        $for N in range(0, BATCH_TILE, 4):
          vf${ABC[N:N+4]} = _mm_or_ps(_mm_and_ps(vf${ABC[N:N+4]}, vm${ABC[N:N+4]}), _mm_andnot_ps(vm${ABC[N:N+4]}, _mm_sub_ps(vone, vf${ABC[N:N+4]})));

      _mm_storeu_ps(y, vf${ABC[0:4]});
      $for N in range(4, BATCH_TILE, 4):
        _mm_storeu_ps(y + ${N}, vf${ABC[N:N+4]});

      x += ${BATCH_TILE};
      y += ${BATCH_TILE};
    }
  // Single-vector loop: 4 floats per iteration, same sequence of operations as
  // the unrolled loop above.
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    const __m128 vx = _mm_loadu_ps(x);

    // z = -|x|
    const __m128 vz = _mm_or_ps(vx, vsign_mask);

    // n = z * log2e + magic_bias; s is built from the bit pattern of n before
    // the bias is subtracted again.
    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);
    const __m128 vs = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn), 23));
    vn = _mm_sub_ps(vn, vmagic_bias);

    // t = z - n * ln2 (high part, then low part)
    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    // Horner evaluation of the degree-5 polynomial's inner part.
    __m128 vp = _mm_add_ps(_mm_mul_ps(vc5, vt), vc4);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc3);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc1);

    // e = s + (t * s) * p
    vt = _mm_mul_ps(vt, vs);
    __m128 ve = _mm_add_ps(_mm_mul_ps(vt, vp), vs);

    // f = e / (e + 1)
    __m128 vd = _mm_add_ps(ve, vone);
    __m128 vf = _mm_div_ps(ve, vd);

    // Zero f where z < vdenorm_cutoff, then pick f or 1 - f by the sign bit
    // of x.
    vf = _mm_andnot_ps(_mm_cmplt_ps(vz, vdenorm_cutoff), vf);
    $if SSE >= 4:
      vf = _mm_blendv_ps(_mm_sub_ps(vone, vf), vf, vx);
    $else:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
      vf = _mm_or_ps(_mm_and_ps(vf, vm), _mm_andnot_ps(vm, _mm_sub_ps(vone, vf)));

    _mm_storeu_ps(y, vf);

    x += 4;
    y += 4;
  }
  // Remainder: n is now 1, 2 or 3 floats' worth of bytes. A whole vector is
  // still loaded and computed; only the valid lanes are stored.
  if XNN_UNLIKELY(n != 0) {
    // Full 16-byte load although fewer than 16 bytes of input remain.
    const __m128 vx = _mm_loadu_ps(x);

    const __m128 vz = _mm_or_ps(vx, vsign_mask);

    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);
    const __m128 vs = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn), 23));
    vn = _mm_sub_ps(vn, vmagic_bias);

    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    __m128 vp = _mm_add_ps(_mm_mul_ps(vc5, vt), vc4);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc3);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc1);

    vt = _mm_mul_ps(vt, vs);
    __m128 ve = _mm_add_ps(_mm_mul_ps(vt, vp), vs);

    __m128 vd = _mm_add_ps(ve, vone);
    __m128 vf = _mm_div_ps(ve, vd);

    vf = _mm_andnot_ps(_mm_cmplt_ps(vz, vdenorm_cutoff), vf);
    $if SSE >= 4:
      vf = _mm_blendv_ps(_mm_sub_ps(vone, vf), vf, vx);
    $else:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
      vf = _mm_or_ps(_mm_and_ps(vf, vm), _mm_andnot_ps(vm, _mm_sub_ps(vone, vf)));

    // Two or three floats left: store the low two lanes, then move the high
    // two lanes down so a possible third element sits in lane 0.
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vf);
      vf = _mm_movehl_ps(vf, vf);
      y += 2;
    }
    // Odd count: store the single float currently in lane 0.
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vf);
    }
  }
}