// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ["ABS", "NEG", "SQR"]
#include <assert.h>

#include <xmmintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/vunary.h>


$_MM_OP_PS = {
$  "ABS": lambda x: "_mm_and_ps(%s, vnonsign_mask)" % x,
$  "NEG": lambda x: "_mm_xor_ps(%s, vsign_mask)" % x,
$  "SQR": lambda x: "_mm_mul_ps(%s, %s)" % (x, x),
$}[OP]
$PARAMS = {
$  "ABS": "const union xnn_f32_abs_params params[restrict XNN_MIN_ELEMENTS(1)]",
$  "NEG": "const union xnn_f32_neg_params params[restrict XNN_MIN_ELEMENTS(1)]",
$  "SQR": "const void* params",
$}[OP]
void xnn_f32_v${OP.lower()}_ukernel__sse_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    ${PARAMS}) XNN_DISABLE_TSAN
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  $if OP == "ABS":
    // Mask that clears the sign bit of each element.
    const __m128 vnonsign_mask = _mm_load_ps(params->sse.nonsign_mask);
  $elif OP == "NEG":
    // Mask that flips the sign bit of each element.
    const __m128 vsign_mask = _mm_load_ps(params->sse.sign_mask);
  // Main loop: apply the operation to ${BATCH_TILE} elements per iteration.
  for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
    const __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
    $for N in range(4, BATCH_TILE, 4):
      const __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
    x += ${BATCH_TILE};

    $for N in range(0, BATCH_TILE, 4):
      const __m128 vy${ABC[N:N+4]} = ${_MM_OP_PS("vx" + ABC[N:N+4])};

    _mm_storeu_ps(y, vy${ABC[0:4]});
    $for N in range(4, BATCH_TILE, 4):
      _mm_storeu_ps(y + ${N}, vy${ABC[N:N+4]});
    y += ${BATCH_TILE};
  }
  $if BATCH_TILE > 4:
    // Process remaining groups of 4 elements.
    for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
      const __m128 vx = _mm_loadu_ps(x);
      x += 4;
      const __m128 vy = ${_MM_OP_PS("vx")};
      _mm_storeu_ps(y, vy);
      y += 4;
    }
  // Handle the final 1-3 elements.
  if XNN_UNLIKELY(n != 0) {
    const __m128 vx = _mm_loadu_ps(x);
    __m128 vy = ${_MM_OP_PS("vx")};
    if (n & (2 * sizeof(float))) {
      // Store two elements, then shift the upper lanes down for a possible scalar store.
      _mm_storel_pi((__m64*) y, vy);
      vy = _mm_movehl_ps(vy, vy);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vy);
    }
  }
}
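
// Generation note: concrete kernels are produced from this template with tools/xngen,
// substituting OP and BATCH_TILE. An illustrative invocation (paths, tile size, and the
// template filename are assumptions, not taken from this file) might look like:
//   tools/xngen src/f32-vunary/sse.c.in -D OP=ABS -D BATCH_TILE=8 -o src/f32-vunary/gen/vabs-sse-x8.c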