// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR % 8 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>


void xnn_f16_gemm_minmax_ukernel_${MR}x${NR}__avx2_broadcast(
    size_t mr,
    size_t nc,
    size_t kc,
    const void*restrict a,
    size_t a_stride,
    const void*restrict w,
    void*restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(uint16_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  // Set up per-row pointers into A and C; rows beyond mr alias the previous row.
  const uint16_t* a0 = a;
  uint16_t* c0 = c;
  $for M in range(1, MR):
    const uint16_t* a${M} = (const uint16_t*) ((uintptr_t) a${M-1} + a_stride);
    uint16_t* c${M} = (uint16_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
    // Initialize accumulators from the packed weights (bias values).
    __m256 vacc0x${ABC[0:8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
    $for N in range(8, NR, 8):
      __m256 vacc0x${ABC[N:N+8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + ${N})));
    $for M in range(1, MR):
      $for N in range(0, NR, 8):
        __m256 vacc${M}x${ABC[N:N+8]} = vacc0x${ABC[N:N+8]};
    w = (const uint16_t*) w + ${NR};

    // Multiply-accumulate: broadcast one f16 element per row of A against an
    // NR-wide row of B, rounding each intermediate result back to f16.
    size_t k = kc;
    do {
      $for M in range(MR):
        const __m256 va${M} = _mm256_cvtph_ps(_mm_set1_epi16((short) *a${M}));
        a${M} += 1;

      const __m256 vb${ABC[0:8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
      $for N in range(8, NR, 8):
        const __m256 vb${ABC[N:N+8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + ${N})));
      w = (const uint16_t*) w + ${NR};

      $for N in range(0, NR, 8):
        $for M in range(MR):
          vacc${M}x${ABC[N:N+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va${M}, vb${ABC[N:N+8]}, vacc${M}x${ABC[N:N+8]}), _MM_FROUND_NO_EXC));

      k -= sizeof(uint16_t);
    } while (k != 0);

    // Apply the scale, then clamp to [min, max].
    const __m256 vscale = _mm256_load_ps(params->avx.scale);
    $for N in range(0, NR, 8):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vacc${M}x${ABC[N:N+8]}, vscale), _MM_FROUND_NO_EXC));

    const __m256 vmin = _mm256_load_ps(params->avx.min);
    $for N in range(0, NR, 8):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+8]} = _mm256_max_ps(vacc${M}x${ABC[N:N+8]}, vmin);

    const __m256 vmax = _mm256_load_ps(params->avx.max);
    $for N in range(0, NR, 8):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+8]} = _mm256_min_ps(vacc${M}x${ABC[N:N+8]}, vmax);

    // Store a full tile of NR columns per row, then rewind A for the next tile.
    if XNN_LIKELY(nc >= ${NR}) {
      $for M in range(MR):
        _mm_storeu_si128((__m128i*) c${M}, _mm256_cvtps_ph(vacc${M}x${ABC[0:8]}, _MM_FROUND_NO_EXC));
        $for N in range(8, NR, 8):
          _mm_storeu_si128((__m128i*) (c${M} + ${N}), _mm256_cvtps_ph(vacc${M}x${ABC[N:N+8]}, _MM_FROUND_NO_EXC));
        c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const uint16_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
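      // Remainder tile (nc < NR): peel off the leftover columns in
      // power-of-two chunks driven by the bits of nc. At the 8-column step the
      // accumulators are narrowed to f16 (the vh registers); 8-, 4-, 2- and
      // 1-element stores then write out whatever remains.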
      $for LOG2N in reversed(range(NR.bit_length())):
        $if LOG2N == 3:
          $for M in range(MR):
            __m128i vh${M}x${ABC[0:8]} = _mm256_cvtps_ph(vacc${M}x${ABC[0:8]}, _MM_FROUND_NO_EXC);
        $if NR != 1 << LOG2N:
          if (nc & ${1 << LOG2N}) {
            $if LOG2N >= 4:
              $for M in range(MR):
                _mm_storeu_si128((__m128i*) c${M}, _mm256_cvtps_ph(vacc${M}x${ABC[0:8]}, _MM_FROUND_NO_EXC));
                $for N in range(8, 1 << LOG2N, 8):
                  _mm_storeu_si128((__m128i*) (c${M} + ${N}), _mm256_cvtps_ph(vacc${M}x${ABC[N:N+8]}, _MM_FROUND_NO_EXC));

              $for M in range(MR):
                $for N in range(0, 1 << (LOG2N - 1), 8):
                  vacc${M}x${ABC[N:N+8]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+8]};

              $for M in range(MR):
                c${M} += ${1 << LOG2N};
            $elif LOG2N == 3:
              $for M in range(MR):
                _mm_storeu_si128((__m128i*) c${M}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                vh${M}x${ABC[0:8]} = _mm256_cvtps_ph(vacc${M}x${ABC[8:16]}, _MM_FROUND_NO_EXC);

              $for M in range(MR):
                c${M} += ${1 << LOG2N};
            $elif LOG2N == 2:
              $for M in range(MR):
                _mm_storel_epi64((__m128i*) c${M}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                vh${M}x${ABC[0:8]} = _mm_unpackhi_epi64(vh${M}x${ABC[0:8]}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                c${M} += 4;
            $elif LOG2N == 1:
              $for M in range(MR):
                _mm_storeu_si32(c${M}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                vh${M}x${ABC[0:8]} = _mm_srli_epi64(vh${M}x${ABC[0:8]}, 32);

              $for M in range(MR):
                c${M} += 2;
            $elif LOG2N == 0:
              $for M in range(MR):
                *c${M} = (uint16_t) _mm_extract_epi16(vh${M}x${ABC[0:8]}, 0);
          }

      nc = 0;
    }
  } while (nc != 0);
}
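
// Usage sketch (illustrative assumptions, not part of the generated kernels):
// kc and a_stride are byte quantities over uint16_t (f16) data, cm_stride and
// cn_stride are byte strides within C, and w points to packed weights laid out
// as NR accumulator-initializer (bias) values followed by NR filter values per
// k step. Assuming the 4x16 instantiation of this template, contiguous output
// columns, and params initialized with the AVX scale/min/max values, a call
// could look like:
//
//   xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast(
//       /*mr=*/4, /*nc=*/n, /*kc=*/k * sizeof(uint16_t),
//       a, /*a_stride=*/k * sizeof(uint16_t),
//       packed_w, c, /*cm_stride=*/n * sizeof(uint16_t),
//       /*cn_stride=*/16 * sizeof(uint16_t), &params);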