1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert CHANNEL_TILE % 8 == 0 7$assert CHANNEL_TILE >= 8 8$assert ROW_TILE >= 1 9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10#include <assert.h> 11 12#include <immintrin.h> 13 14#include <xnnpack/math.h> 15#include <xnnpack/prelu.h> 16 17 18void xnn_f16_prelu_ukernel__f16c_${ROW_TILE}x${CHANNEL_TILE}( 19 size_t rows, 20 size_t channels, 21 const void* restrict input, 22 size_t input_stride, 23 const void* restrict weights, 24 void* restrict output, 25 size_t output_stride) XNN_OOB_READS 26{ 27 assert(rows != 0); 28 assert(channels != 0); 29 assert(channels % sizeof(uint16_t) == 0); 30 31 const uint16_t* i0 = (const uint16_t*) input; 32 uint16_t* o0 = (uint16_t*) output; 33 $for M in range(1, ROW_TILE): 34 const uint16_t* i${M} = (const uint16_t*) ((uintptr_t) i${M-1} + input_stride); 35 uint16_t* o${M} = (uint16_t*) ((uintptr_t) o${M-1} + output_stride); 36 37 const size_t input_increment = input_stride * ${ROW_TILE} - channels; 38 const size_t output_increment = output_stride * ${ROW_TILE} - channels; 39 40 do { 41 $for M in range(1, ROW_TILE): 42 $if M % 2 == 0: 43 if XNN_UNPREDICTABLE(rows <= ${M}) { 44 i${M} = i${M-1}; 45 o${M} = o${M-1}; 46 } 47 $else: 48 if XNN_UNPREDICTABLE(rows < ${M+1}) { 49 i${M} = i${M-1}; 50 o${M} = o${M-1}; 51 } 52 53 const uint16_t* w = (const uint16_t*) weights; 54 size_t c = channels; 55 $if CHANNEL_TILE > 8: 56 for (; c >= ${CHANNEL_TILE} * sizeof(uint16_t); c -= ${CHANNEL_TILE} * sizeof(uint16_t)) { 57 const __m256 vw${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w)); 58 $for C in range(8, CHANNEL_TILE, 8): 59 const __m256 vw${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + ${C}))); 60 w += ${CHANNEL_TILE}; 61 62 $for M in range(ROW_TILE): 63 const __m256 vi${M}x0${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M})); 64 $for C in range(8, CHANNEL_TILE, 8): 65 const __m256 vi${M}x0${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${M} + ${C}))); 66 i${M} += ${CHANNEL_TILE}; 67 68 $for M in range(ROW_TILE): 69 $for C in range(0, CHANNEL_TILE, 8): 70 __m256 vacc${M}x0${ABC[C:C+8]} = _mm256_mul_ps(vi${M}x0${ABC[C:C+8]}, vw${ABC[C:C+8]}); 71 72 $for M in range(ROW_TILE): 73 $for C in range(0, CHANNEL_TILE, 8): 74 vacc${M}x0${ABC[C:C+8]} = _mm256_blendv_ps(vi${M}x0${ABC[C:C+8]}, vacc${M}x0${ABC[C:C+8]}, vi${M}x0${ABC[C:C+8]}); 75 76 $for M in range(ROW_TILE): 77 _mm_storeu_si128((__m128i*) o${M}, _mm256_cvtps_ph(vacc${M}x0${ABC[C:C+8]}, _MM_FROUND_NO_EXC)); 78 $for C in range(0, CHANNEL_TILE, 8): 79 _mm_storeu_si128((__m128i*) (o${M} + ${C}), _mm256_cvtps_ph(vacc${M}x0${ABC[C:C+8]}, _MM_FROUND_NO_EXC)); 80 o${M} += ${CHANNEL_TILE}; 81 } 82 for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) { 83 const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w)); 84 w += 8; 85 86 $for M in range(ROW_TILE): 87 const __m256 vi${M}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M})); 88 i${M} += 8; 89 90 $for M in range(ROW_TILE): 91 __m256 vacc${M}x01234567 = _mm256_mul_ps(vi${M}x01234567, vw01234567); 92 93 $for M in range(ROW_TILE): 94 vacc${M}x01234567 = _mm256_blendv_ps(vi${M}x01234567, vacc${M}x01234567, vi${M}x01234567); 95 96 $for M in range(ROW_TILE): 97 _mm_storeu_si128((__m128i*) o${M}, _mm256_cvtps_ph(vacc${M}x01234567, _MM_FROUND_NO_EXC)); 98 o${M} += 8; 99 } 100 if XNN_UNLIKELY(c != 0) { 101 const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w)); 102 103 $for M in range(ROW_TILE): 104 const __m256 vi${M}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M})); 105 i${M} = (const uint16_t*) ((uintptr_t) i${M} + c); 106 107 $for M in range(ROW_TILE): 108 __m256 vacc${M}x01234567 = _mm256_mul_ps(vi${M}x01234567, vw01234567); 109 110 $for M in range(ROW_TILE): 111 vacc${M}x01234567 = _mm256_blendv_ps(vi${M}x01234567, vacc${M}x01234567, vi${M}x01234567); 112 113 $for M in range(ROW_TILE): 114 __m128i vh${M}x01234567 = _mm256_cvtps_ph(vacc${M}x01234567, _MM_FROUND_NO_EXC); 115 if (c & (4 * sizeof(uint16_t))) { 116 $for M in range(ROW_TILE): 117 _mm_storel_epi64((__m128i*) o${M}, vh${M}x01234567); 118 119 $for M in range(ROW_TILE): 120 vh${M}x01234567 = _mm_unpackhi_epi64(vh${M}x01234567, vh${M}x01234567); 121 122 $for M in range(ROW_TILE): 123 o${M} += 4; 124 } 125 if (c & (2 * sizeof(uint16_t))) { 126 $for M in range(ROW_TILE): 127 *((uint32_t*) o${M}) = (uint32_t) _mm_cvtsi128_si32(vh${M}x01234567); 128 129 $for M in range(ROW_TILE): 130 vh${M}x01234567 = _mm_srli_epi64(vh${M}x01234567, 32); 131 132 $for M in range(ROW_TILE): 133 o${M} += 2; 134 } 135 if (c & (1 * sizeof(uint16_t))) { 136 $for M in range(ROW_TILE): 137 *o${M} = (uint16_t) _mm_extract_epi16(vh${M}x01234567, 0); 138 139 $for M in range(ROW_TILE): 140 o${M} += 1; 141 } 142 } 143 $for M in range(ROW_TILE): 144 i${M} = (const uint16_t*) ((uintptr_t) i${M} + input_increment); 145 o${M} = (uint16_t*) ((uintptr_t) o${M} + output_increment); 146 rows = doz(rows, ${ROW_TILE}); 147 } while (rows != 0); 148} 149