// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE % 8 == 0
$assert KERNEL_TILE >= 2
$assert ACCUMULATORS >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/intrinsics-polyfill.h>


void xnn_f16_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__fma3${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const void* zero,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(channels != 0);
  assert(output_width != 0);

  const __m256 vmax = _mm256_load_ps(params->avx.max);
  const __m256 vmin = _mm256_load_ps(params->avx.min);

  uint16_t* o = (uint16_t*) output;
  do {
    $for K in range(KERNEL_TILE):
      const uint16_t* i${K} = input[${K}];
      assert(i${K} != NULL);
      if XNN_UNPREDICTABLE(i${K} != zero) {
        i${K} = (const uint16_t*) ((uintptr_t) i${K} + input_offset);
      }
    input = (const void**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const uint16_t* w = weights;
    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
      __m256 vacc${ABC[0:8]}p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
      $for C in range(8, CHANNEL_TILE, 8):
        __m256 vacc${ABC[C:C+8]}p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${C})));

      $for K in range(KERNEL_TILE):

        const __m256 vi${K}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
        $for C in range(8, CHANNEL_TILE, 8):
          const __m256 vi${K}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
        i${K} += ${CHANNEL_TILE};

        $for C in range(0, CHANNEL_TILE, 8):
          const __m256 vk${K}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE + C})));
        $for C in range(0, CHANNEL_TILE, 8):
          $if 1 <= K < ACCUMULATORS:
            __m256 vacc${ABC[C:C+8]}p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC));
          $else:
            vacc${ABC[C:C+8]}p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}, vacc${ABC[C:C+8]}p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));

      w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};

      $if ACCUMULATORS > 1:
        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for C in range(0, CHANNEL_TILE, 8):
                vacc${ABC[C:C+8]}p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc${ABC[C:C+8]}p${A}, vacc${ABC[C:C+8]}p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
          $ACC_SLICE *= 2

      $for C in range(0, CHANNEL_TILE, 8):
        __m256 vacc${ABC[C:C+8]} = _mm256_max_ps(vacc${ABC[C:C+8]}p0, vmin);
      $for C in range(0, CHANNEL_TILE, 8):
        vacc${ABC[C:C+8]} = _mm256_min_ps(vacc${ABC[C:C+8]}, vmax);

      _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc${ABC[0:8]}, _MM_FROUND_NO_EXC));
      $for C in range(8, CHANNEL_TILE, 8):
        _mm_storeu_si128((__m128i*) (o + ${C}), _mm256_cvtps_ph(vacc${ABC[C:C+8]}, _MM_FROUND_NO_EXC));
      o += ${CHANNEL_TILE};
    }
    $if CHANNEL_TILE > 8:
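      // Remainder loop: process the leftover channels in groups of 8; the final 1-7 channels are handled below.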
      for (; c >= 8; c -= 8) {
        __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
        $for K in range(KERNEL_TILE):

          const __m256 vi${K}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
          i${K} += 8;

          const __m256 vk${K}x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE})));
          $if 1 <= K < ACCUMULATORS:
            __m256 vacc01234567p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x01234567, vk${K}x01234567), _MM_FROUND_NO_EXC));
          $else:
            vacc01234567p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));

        w += 8;

        $if ACCUMULATORS > 1:
          // Add up all accumulators to vacc${ABC[0:8]}p0
          $ACC_SLICE = 1
          $while ACC_SLICE < ACCUMULATORS:
            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
              $if A + ACC_SLICE < ACCUMULATORS:
                vacc01234567p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
            $ACC_SLICE *= 2

        __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
        vacc01234567 = _mm256_min_ps(vacc01234567, vmax);

        _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
        o += 8;
      }
    if XNN_UNLIKELY(c != 0) {
      assert(c >= 1);
      assert(c <= 7);

      __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
      $for K in range(KERNEL_TILE):

        const __m256 vi${K}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));

        const __m256 vk${K}x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE})));
        $if 1 <= K < ACCUMULATORS:
          __m256 vacc01234567p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x01234567, vk${K}x01234567), _MM_FROUND_NO_EXC));
        $else:
          vacc01234567p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));

      $if ACCUMULATORS > 1:
        // Add up all accumulators to vacc${ABC[0:8]}p0
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              vacc01234567p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
          $ACC_SLICE *= 2

      __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
      vacc01234567 = _mm256_min_ps(vacc01234567, vmax);

      __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
      if (c & 4) {
        _mm_storel_epi64((__m128i*) o, vh01234567);
        vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
        o += 4;
      }
      if (c & 2) {
        _mm_storeu_si32(o, vh01234567);
        vh01234567 = _mm_srli_epi64(vh01234567, 32);
        o += 2;
      }
      if (c & 1) {
        *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
        o += 1;
      }
    }

    o = (uint16_t*) ((uintptr_t) o + output_increment);
  } while (--output_width != 0);
}