1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE >= 1 7#include <assert.h> 8#include <math.h> 9 10#include <xnnpack/common.h> 11#include <xnnpack/math.h> 12#include <xnnpack/vunary.h> 13 14 15extern XNN_INTERNAL const uint32_t xnn_table_exp2minus_k_over_16[16]; 16 17void xnn_f32_velu_ukernel__${"wasm" if WASM else "scalar"}_rr2_lut16_p3_x${BATCH_TILE}( 18 size_t n, 19 const float* x, 20 float* y, 21 const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) 22{ 23 assert(n % sizeof(float) == 0); 24 25 const float vprescale = params->scalar_rr2_lut16_p3.prescale; 26 const float valpha = params->scalar_rr2_lut16_p3.alpha; 27 const float vbeta = params->scalar_rr2_lut16_p3.beta; 28 const float vmagic_bias = params->scalar_rr2_lut16_p3.magic_bias; 29 const float vlog2e = params->scalar_rr2_lut16_p3.log2e; 30 const uint32_t vindex_mask = UINT32_C(0xF); 31 const float vsat_cutoff = params->scalar_rr2_lut16_p3.sat_cutoff; 32 const float vminus_ln2_hi = params->scalar_rr2_lut16_p3.minus_ln2_hi; 33 const float vminus_ln2_lo = params->scalar_rr2_lut16_p3.minus_ln2_lo; 34 const float vc3 = params->scalar_rr2_lut16_p3.c3; 35 const float vc2 = params->scalar_rr2_lut16_p3.c2; 36 const float vone = params->scalar_rr2_lut16_p3.one; 37 38 $if BATCH_TILE > 1: 39 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 40 $for N in range(BATCH_TILE): 41 float vx${N} = x[${N}]; 42 x += ${BATCH_TILE}; 43 44 $for N in range(BATCH_TILE): 45 $if WASM: 46 const float vz${N} = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx${N} * vprescale, vsat_cutoff), 0.0f); 47 $else: 48 const float vz${N} = vx${N} * vprescale; 49 50 $for N in range(BATCH_TILE): 51 float vn${N} = vz${N} * vlog2e + vmagic_bias; 52 53 $for N in range(BATCH_TILE): 54 const uint32_t ven${N} = float_as_uint32(vn${N}) << 19; 55 const uint32_t vidx${N} = float_as_uint32(vn${N}) & vindex_mask; 56 vn${N} -= vmagic_bias; 57 58 $for N in range(BATCH_TILE): 59 float vt${N} = vn${N} * vminus_ln2_hi + vz${N}; 60 float vs${N} = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx${N}] + ven${N}); 61 62 $for N in range(BATCH_TILE): 63 vt${N} = vn${N} * vminus_ln2_lo + vt${N}; 64 $if not WASM: 65 if XNN_UNPREDICTABLE(vz${N} <= vsat_cutoff) { 66 vs${N} = 0.0f; 67 vt${N} = 0.0f; 68 } 69 70 $for N in range(BATCH_TILE): 71 float vp${N} = vc3 * vt${N} + vc2; 72 73 $for N in range(BATCH_TILE): 74 vp${N} *= vt${N}; 75 76 $for N in range(BATCH_TILE): 77 vt${N} *= vs${N}; 78 vs${N} -= vone; 79 80 $for N in range(BATCH_TILE): 81 vp${N} = vp${N} * vt${N} + vt${N}; 82 83 $for N in range(BATCH_TILE): 84 const float ve${N} = (vp${N} + vs${N}) * valpha; 85 $if WASM: 86 float vy${N} = __builtin_wasm_max_f32(vx${N} * vbeta, 0.0f); 87 $else: 88 float vy${N} = vx${N} * vbeta; 89 90 $if WASM: 91 $for N in range(BATCH_TILE): 92 vy${N} += __builtin_wasm_min_f32(ve${N}, 0.0f); 93 $else: 94 $for N in range(BATCH_TILE): 95 if XNN_UNPREDICTABLE(vx${N} < 0.0f) { 96 vy${N} = ve${N}; 97 } 98 99 $for N in range(BATCH_TILE): 100 y[${N}] = vy${N}; 101 y += ${BATCH_TILE}; 102 } 103 $if BATCH_TILE == 1: 104 do { 105 float vx = *x++; 106 107 $if WASM: 108 const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f); 109 $else: 110 const float vz = vx * vprescale; 111 112 float vn = vz * vlog2e + vmagic_bias; 113 const uint32_t ven = float_as_uint32(vn) << 19; 114 const uint32_t vidx = float_as_uint32(vn) & vindex_mask; 115 vn -= vmagic_bias; 116 117 float vt = vn * vminus_ln2_hi + vz; 118 float vs = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx] + ven); 119 120 vt = vn * vminus_ln2_lo + vt; 121 $if not WASM: 122 if XNN_UNPREDICTABLE(vz <= vsat_cutoff) { 123 vs = 0.0f; 124 vt = 0.0f; 125 } 126 127 float vp = vc3 * vt + vc2; 128 vp *= vt; 129 130 vt *= vs; 131 vs -= vone; 132 vp = vp * vt + vt; 133 const float ve = (vp + vs) * valpha; 134 135 $if WASM: 136 float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f); 137 vy += __builtin_wasm_min_f32(ve, 0.0f); 138 $else: 139 float vy = vx * vbeta; 140 if XNN_UNPREDICTABLE(vx < 0.0f) { 141 vy = ve; 142 } 143 144 *y++ = vy; 145 146 n -= sizeof(float); 147 } while (n != 0); 148 $elif BATCH_TILE == 2: 149 if XNN_UNLIKELY(n != 0) { 150 float vx = *x; 151 152 $if WASM: 153 const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f); 154 $else: 155 const float vz = vx * vprescale; 156 157 float vn = vz * vlog2e + vmagic_bias; 158 const uint32_t ven = float_as_uint32(vn) << 19; 159 const uint32_t vidx = float_as_uint32(vn) & vindex_mask; 160 vn -= vmagic_bias; 161 162 float vt = vn * vminus_ln2_hi + vz; 163 float vs = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx] + ven); 164 165 vt = vn * vminus_ln2_lo + vt; 166 $if not WASM: 167 if XNN_UNPREDICTABLE(vz <= vsat_cutoff) { 168 vs = 0.0f; 169 vt = 0.0f; 170 } 171 172 float vp = vc3 * vt + vc2; 173 vp *= vt; 174 175 vt *= vs; 176 vs -= vone; 177 vp = vp * vt + vt; 178 const float ve = (vp + vs) * valpha; 179 180 $if WASM: 181 float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f); 182 vy += __builtin_wasm_min_f32(ve, 0.0f); 183 $else: 184 float vy = vx * vbeta; 185 if XNN_UNPREDICTABLE(vx < 0.0f) { 186 vy = ve; 187 } 188 189 *y = vy; 190 } 191 $else: 192 if XNN_UNLIKELY(n != 0) { 193 do { 194 float vx = *x++; 195 196 $if WASM: 197 const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f); 198 $else: 199 const float vz = vx * vprescale; 200 201 float vn = vz * vlog2e + vmagic_bias; 202 const uint32_t ven = float_as_uint32(vn) << 19; 203 const uint32_t vidx = float_as_uint32(vn) & vindex_mask; 204 vn -= vmagic_bias; 205 206 float vt = vn * vminus_ln2_hi + vz; 207 float vs = uint32_as_float(xnn_table_exp2minus_k_over_16[vidx] + ven); 208 209 vt = vn * vminus_ln2_lo + vt; 210 $if not WASM: 211 if XNN_UNPREDICTABLE(vz <= vsat_cutoff) { 212 vs = 0.0f; 213 vt = 0.0f; 214 } 215 216 float vp = vc3 * vt + vc2; 217 vp *= vt; 218 219 vt *= vs; 220 vs -= vone; 221 vp = vp * vt + vt; 222 const float ve = (vp + vs) * valpha; 223 224 $if WASM: 225 float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f); 226 vy += __builtin_wasm_min_f32(ve, 0.0f); 227 $else: 228 float vy = vx * vbeta; 229 if XNN_UNPREDICTABLE(vx < 0.0f) { 230 vy = ve; 231 } 232 233 *y++ = vy; 234 235 n -= sizeof(float); 236 } while (n != 0); 237 } 238} 239