1// Copyright 2019 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE % 8 == 0 7$assert BATCH_TILE >= 8 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB"] 10#include <assert.h> 11 12#include <immintrin.h> 13 14#include <xnnpack/common.h> 15#include <xnnpack/vbinary.h> 16 17 18static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0}; 19 20$_MM256_OP_PS = { 21$ "ADD": lambda x, y: "_mm256_add_ps(%s, %s)" % (x, y), 22$ "DIV": lambda x, y: "_mm256_div_ps(%s, %s)" % (x, y), 23$ "MAX": lambda x, y: "_mm256_max_ps(%s, %s)" % (x, y), 24$ "MIN": lambda x, y: "_mm256_min_ps(%s, %s)" % (x, y), 25$ "MUL": lambda x, y: "_mm256_mul_ps(%s, %s)" % (x, y), 26$ "SUB": lambda x, y: "_mm256_sub_ps(%s, %s)" % (x, y), 27$}[OP] 28void xnn_f32_v${OP.lower()}_ukernel__avx_x${BATCH_TILE}( 29 size_t n, 30 const float* a, 31 const float* b, 32 float* y, 33 const union xnn_f32_output_params params[restrict static 1]) 34{ 35 assert(n != 0); 36 assert(n % sizeof(float) == 0); 37 38 const __m256 vy_min = _mm256_broadcast_ps((const __m128*) params->sse.min); 39 const __m256 vy_max = _mm256_broadcast_ps((const __m128*) params->sse.max); 40 41 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 42 const __m256 va${ABC[0:8]} = _mm256_loadu_ps(a); 43 $for N in range(8, BATCH_TILE, 8): 44 const __m256 va${ABC[N:N+8]} = _mm256_loadu_ps(a + ${N}); 45 a += ${BATCH_TILE}; 46 47 const __m256 vb${ABC[0:8]} = _mm256_loadu_ps(b); 48 $for N in range(8, BATCH_TILE, 8): 49 const __m256 vb${ABC[N:N+8]} = _mm256_loadu_ps(b + ${N}); 50 b += ${BATCH_TILE}; 51 52 $for N in range(0, BATCH_TILE, 8): 53 __m256 vy${ABC[N:N+8]} = ${_MM256_OP_PS("va" + ABC[N:N+8], "vb" + ABC[N:N+8])}; 54 55 $for N in range(0, BATCH_TILE, 8): 56 vy${ABC[N:N+8]} = _mm256_max_ps(vy${ABC[N:N+8]}, vy_min); 57 58 $for N in range(0, BATCH_TILE, 8): 59 vy${ABC[N:N+8]} = _mm256_min_ps(vy${ABC[N:N+8]}, vy_max); 60 61 _mm256_storeu_ps(y, vy${ABC[0:8]}); 62 $for N in range(8, BATCH_TILE, 8): 63 _mm256_storeu_ps(y + ${N}, vy${ABC[N:N+8]}); 64 y += ${BATCH_TILE}; 65 } 66 $if BATCH_TILE >= 8: 67 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) { 68 const __m256 va = _mm256_loadu_ps(a); 69 a += 8; 70 71 const __m256 vb = _mm256_loadu_ps(b); 72 b += 8; 73 74 __m256 vy = ${_MM256_OP_PS("va", "vb")}; 75 vy = _mm256_max_ps(vy, vy_min); 76 vy = _mm256_min_ps(vy, vy_max); 77 _mm256_storeu_ps(y, vy); 78 y += 8; 79 } 80 if XNN_UNLIKELY(n != 0) { 81 assert(n >= 1 * sizeof(float)); 82 assert(n <= 7 * sizeof(float)); 83 __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n)); 84 85 const __m256 va = _mm256_maskload_ps(a, vmask); 86 const __m256 vb = _mm256_maskload_ps(b, vmask); 87 88 __m256 vy = ${_MM256_OP_PS("va", "vb")}; 89 vy = _mm256_max_ps(vy, vy_min); 90 vy = _mm256_min_ps(vy, vy_max); 91 92 // _mm256_maskstore_ps(y, vmask, vy) could be used here, but triggers msan failures (probably an msan bug). 93 __m128 vy_lo = _mm256_castps256_ps128(vy); 94 if (n & (4 * sizeof(float))) { 95 _mm_storeu_ps(y, vy_lo); 96 vy_lo = _mm256_extractf128_ps(vy, 1); 97 y += 4; 98 } 99 if (n & (2 * sizeof(float))) { 100 _mm_storel_pi((__m64*) y, vy_lo); 101 vy_lo = _mm_movehl_ps(vy_lo, vy_lo); 102 y += 2; 103 } 104 if (n & (1 * sizeof(float))) { 105 _mm_store_ss(y, vy_lo); 106 } 107 } 108} 109