1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE >= 1 7#include <assert.h> 8#include <math.h> 9 10#include <xnnpack/common.h> 11#include <xnnpack/vunary.h> 12 13#include <fp16/bitcasts.h> 14 15 16void xnn_f32_velu_ukernel__${"wasm" if WASM else "scalar"}_rr2_p6_x${BATCH_TILE}( 17 size_t n, 18 const float* x, 19 float* y, 20 const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) 21{ 22 assert(n % sizeof(float) == 0); 23 24 const float vprescale = params->scalar.prescale; 25 const float valpha = params->scalar.alpha; 26 const float vbeta = params->scalar.beta; 27 28 const float vmagic_bias = 0x1.8000FEp23f; 29 const float vlog2e = 0x1.715476p+0f; 30 const float vsat_cutoff = -0x1.154246p+4f; 31 const float vminus_ln2_hi = -0x1.62E440p-1f; 32 const float vminus_ln2_lo = 0x1.0105C6p-21f; 33 const float vc6 = 0x1.6b7338p-10f; 34 const float vc5 = 0x1.12278Ep-7f; 35 const float vc4 = 0x1.555716p-5f; 36 const float vc3 = 0x1.5554B0p-3f; 37 const float vc2 = 0x1.FFFFFEp-2f; 38 const float vone = 1.0f; 39 40 $if BATCH_TILE > 1: 41 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 42 $for N in range(BATCH_TILE): 43 float vx${N} = x[${N}]; 44 x += ${BATCH_TILE}; 45 46 $for N in range(BATCH_TILE): 47 $if WASM: 48 const float vz${N} = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx${N} * vprescale, vsat_cutoff), 0.0f); 49 $else: 50 const float vz${N} = vx${N} * vprescale; 51 52 $for N in range(BATCH_TILE): 53 float vn${N} = vz${N} * vlog2e + vmagic_bias; 54 55 $for N in range(BATCH_TILE): 56 float vs${N} = fp32_from_bits(fp32_to_bits(vn${N}) << 23); 57 vn${N} -= vmagic_bias; 58 59 $for N in range(BATCH_TILE): 60 float vt${N} = vn${N} * vminus_ln2_hi + vz${N}; 61 62 $for N in range(BATCH_TILE): 63 vt${N} = vn${N} * vminus_ln2_lo + vt${N}; 64 65 $if not WASM: 66 $for N in range(BATCH_TILE): 67 if XNN_UNPREDICTABLE(vz${N} <= vsat_cutoff) { 68 vs${N} = 0.0f; 69 vt${N} = 0.0f; 70 } 71 72 $for N in range(BATCH_TILE): 73 float vp${N} = vc6 * vt${N} + vc5; 74 75 $for N in range(BATCH_TILE): 76 vp${N} = vp${N} * vt${N} + vc4; 77 78 $for N in range(BATCH_TILE): 79 vp${N} = vp${N} * vt${N} + vc3; 80 81 $for N in range(BATCH_TILE): 82 vp${N} = vp${N} * vt${N} + vc2; 83 84 $for N in range(BATCH_TILE): 85 vp${N} *= vt${N}; 86 87 $for N in range(BATCH_TILE): 88 vt${N} *= vs${N}; 89 vs${N} -= vone; 90 91 $for N in range(BATCH_TILE): 92 vp${N} = vp${N} * vt${N} + vt${N}; 93 94 $for N in range(BATCH_TILE): 95 const float ve${N} = (vp${N} + vs${N}) * valpha; 96 $if WASM: 97 float vy${N} = __builtin_wasm_max_f32(vx${N} * vbeta, 0.0f); 98 $else: 99 float vy${N} = vx${N} * vbeta; 100 101 $if WASM: 102 $for N in range(BATCH_TILE): 103 vy${N} += __builtin_wasm_min_f32(ve${N}, 0.0f); 104 $else: 105 $for N in range(BATCH_TILE): 106 if XNN_UNPREDICTABLE(vx${N} < 0.0f) { 107 vy${N} = ve${N}; 108 } 109 110 $for N in range(BATCH_TILE): 111 y[${N}] = vy${N}; 112 y += ${BATCH_TILE}; 113 } 114 $if BATCH_TILE == 1: 115 do { 116 float vx = *x++; 117 118 $if WASM: 119 const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f); 120 $else: 121 const float vz = vx * vprescale; 122 123 float vn = vz * vlog2e + vmagic_bias; 124 float vs = fp32_from_bits(fp32_to_bits(vn) << 23); 125 vn -= vmagic_bias; 126 127 float vt = vn * vminus_ln2_hi + vz; 128 vt = vn * vminus_ln2_lo + vt; 129 130 $if not WASM: 131 if XNN_UNPREDICTABLE(vz <= vsat_cutoff) { 132 vs = 0.0f; 133 vt = 0.0f; 134 } 135 136 float vp = vc6 * vt + vc5; 137 vp = vp * vt + vc4; 138 vp = vp * vt + vc3; 139 vp = vp * vt + vc2; 140 vp *= vt; 141 142 vt *= vs; 143 vs -= vone; 144 vp = vp * vt + vt; 145 const float ve = (vp + vs) * valpha; 146 147 $if WASM: 148 float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f); 149 vy += __builtin_wasm_min_f32(ve, 0.0f); 150 $else: 151 float vy = vx * vbeta; 152 if XNN_UNPREDICTABLE(vx < 0.0f) { 153 vy = ve; 154 } 155 156 *y++ = vy; 157 158 n -= sizeof(float); 159 } while (n != 0); 160 $elif BATCH_TILE == 2: 161 if XNN_UNLIKELY(n != 0) { 162 float vx = *x; 163 164 $if WASM: 165 const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f); 166 $else: 167 const float vz = vx * vprescale; 168 169 float vn = vz * vlog2e + vmagic_bias; 170 float vs = fp32_from_bits(fp32_to_bits(vn) << 23); 171 vn -= vmagic_bias; 172 173 float vt = vn * vminus_ln2_hi + vz; 174 vt = vn * vminus_ln2_lo + vt; 175 176 $if not WASM: 177 if XNN_UNPREDICTABLE(vz <= vsat_cutoff) { 178 vs = 0.0f; 179 vt = 0.0f; 180 } 181 182 float vp = vc6 * vt + vc5; 183 vp = vp * vt + vc4; 184 vp = vp * vt + vc3; 185 vp = vp * vt + vc2; 186 vp *= vt; 187 188 vt *= vs; 189 vs -= vone; 190 vp = vp * vt + vt; 191 const float ve = (vp + vs) * valpha; 192 193 $if WASM: 194 float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f); 195 vy += __builtin_wasm_min_f32(ve, 0.0f); 196 $else: 197 float vy = vx * vbeta; 198 if XNN_UNPREDICTABLE(vx < 0.0f) { 199 vy = ve; 200 } 201 202 *y = vy; 203 } 204 $else: 205 if XNN_UNLIKELY(n != 0) { 206 do { 207 float vx = *x++; 208 209 $if WASM: 210 const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f); 211 $else: 212 const float vz = vx * vprescale; 213 214 float vn = vz * vlog2e + vmagic_bias; 215 float vs = fp32_from_bits(fp32_to_bits(vn) << 23); 216 vn -= vmagic_bias; 217 218 float vt = vn * vminus_ln2_hi + vz; 219 vt = vn * vminus_ln2_lo + vt; 220 221 $if not WASM: 222 if XNN_UNPREDICTABLE(vz <= vsat_cutoff) { 223 vs = 0.0f; 224 vt = 0.0f; 225 } 226 227 float vp = vc6 * vt + vc5; 228 vp = vp * vt + vc4; 229 vp = vp * vt + vc3; 230 vp = vp * vt + vc2; 231 vp *= vt; 232 233 vt *= vs; 234 vs -= vone; 235 vp = vp * vt + vt; 236 const float ve = (vp + vs) * valpha; 237 238 $if WASM: 239 float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f); 240 vy += __builtin_wasm_min_f32(ve, 0.0f); 241 $else: 242 float vy = vx * vbeta; 243 if XNN_UNPREDICTABLE(vx < 0.0f) { 244 vy = ve; 245 } 246 247 *y++ = vy; 248 249 n -= sizeof(float); 250 } while (n != 0); 251 } 252} 253