// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8"]
$assert CHANNEL_TILE % 16 == 0
$assert CHANNEL_TILE >= 16
$assert KERNEL_TILE >= 2
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/unaligned.h>


$PARAMS_STRUCT = REQUANTIZATION.lower() + "_avx2"
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
// Unipass depthwise-convolution microkernel template for signed 8-bit inputs
// (QS8: single global scale; QC8: per-channel scales) with FP32 requantization.
//
// AVX2 "mul16 + vpunpck" strategy: inputs and kernel taps are sign-extended to
// int16 and multiplied with _mm256_mullo_epi16.  Because both operands are in
// [-128, 127] (really int8), the full product fits in 16 bits, so the low
// halves ARE the products; the "high" word of each 32-bit term is recovered by
// arithmetic-shifting the product right by 15 (a per-element sign mask) and
// interleaving with unpacklo/unpackhi.  AVX2 unpacks operate within 128-bit
// lanes, so the int32 accumulators are pre-shuffled with
// inserti128/permute2x128 into "lo-halves" / "hi-halves" pairs and shuffled
// back after accumulation.
//
// Weight layout consumed here (inferred from the offsets below; produced by
// XNNPACK's dwconv weight packing — confirm against the packing routine):
// per CHANNEL_TILE group: CHANNEL_TILE int32 biases, then
// KERNEL_TILE*CHANNEL_TILE int8 taps (tap-major), then for QC8
// CHANNEL_TILE float scales.
void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx2_mul16${"_add16" if ADD16 else ""}_vpunpck(
    size_t channels,
    size_t output_width,
    const int8_t** input,
    const void* weights,
    int8_t* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const int8_t* zero,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(channels != 0);
  assert(output_width != 0);

  do {
    // Set up one input row pointer per kernel tap.  Rows equal to the shared
    // `zero` buffer are padding and must NOT be offset by input_offset.
    $for K in range(KERNEL_TILE):
      const int8_t* i${K} = input[${K}];
      assert(i${K} != NULL);
      if XNN_UNPREDICTABLE(i${K} != zero) {
        i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
      }
    input = (const int8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
    // Main loop: CHANNEL_TILE channels per iteration.
    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
      // Initialize int32 accumulators from the packed biases.
      __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
      $for C in range(8, CHANNEL_TILE, 8):
        __m256i vacc${ABC[C:C+8]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + ${C} * sizeof(int32_t)));

      // Re-pair the accumulators across 128-bit lanes so that the in-lane
      // unpacklo/unpackhi results below land on the matching channels:
      // one register holds channels {C..C+3, C+8..C+11}, the other
      // {C+4..C+7, C+12..C+15}.
      $for C in range(0, CHANNEL_TILE, 16):
        __m256i vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_inserti128_si256(vacc${ABC[C:C+8]}, _mm256_castsi256_si128(vacc${ABC[C+8:C+16]}), 1);
        __m256i vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}, 0x31);

      $for K in range(KERNEL_TILE):

        // Sign-extend 16 input bytes and 16 kernel bytes to int16 per step.
        $for C in range(0, CHANNEL_TILE, 16):
          $if C == 0:
            const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
          $else:
            const __m256i vi${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
          const __m256i vk${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(int8_t))));
        i${K} += ${CHANNEL_TILE};

        $if ADD16:
          // ADD16 variant: sum pairs of taps in 16 bits before widening.
          // Safe from overflow: |int8*int8| <= 128*127 = 16256, and
          // 2*16256 = 32512 still fits in int16.
          $for C in range(0, CHANNEL_TILE, 16):
            $if K == 0:
              __m256i vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
            $elif K % 2 == 0 or K + 1 == KERNEL_TILE:
              vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
            $else:
              vacc${ABC[C:C+16]} = _mm256_add_epi16(vacc${ABC[C:C+16]}, _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}));

          // After each odd tap (or a lone final tap): sign-extend the 16-bit
          // pair-sum into the 32-bit accumulators.  srai(x, 15) yields the
          // sign mask used as the high word of each unpacked int32.
          $if K % 2 == 1 or K + 1 == KERNEL_TILE:
            $for C in range(0, CHANNEL_TILE, 16):
              $if K == 1:
                __m256i vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15);
              $else:
                vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15);
              vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]}));
              vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]}));
        $else:
          // Per-tap widening: the int8*int8 product fits in int16, so the
          // mullo result is exact and srai(prod, 15) supplies the sign
          // extension for the 16->32-bit unpack.
          $for C in range(0, CHANNEL_TILE, 16):
            const __m256i vprod${K}x${ABC[C:C+16]}lo = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
            const __m256i vprod${K}x${ABC[C:C+16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[C:C+16]}lo, 15);

          $for C in range(0, CHANNEL_TILE, 16):
            vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi));
            vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi));

      // Advance past this group's biases and taps (QC8 scales, if any, are
      // consumed and skipped below).
      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(int8_t));

      // Undo the lane re-pairing: restore plain channel order 0..15 per pair.
      $for C in range(0, CHANNEL_TILE, 16):
        vacc${ABC[C:C+8]} = _mm256_inserti128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_castsi256_si128(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}), 1);
        vacc${ABC[C+8:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 0x31);

      // FP32 requantization: int32 -> float, scale, clamp-above, round back.
      $for C in range(0, CHANNEL_TILE, 8):
        __m256 vfpacc${ABC[C:C+8]} = _mm256_cvtepi32_ps(vacc${ABC[C:C+8]});

      $if DATATYPE == "QC8":
        // Per-channel scales follow the taps in the weight stream.
        const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) w);
        $for C in range(8, CHANNEL_TILE, 8):
          const __m256 vscale${ABC[C:C+8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${C} * sizeof(float)));
        w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(float));
        $for C in range(0, CHANNEL_TILE, 8):
          vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale${ABC[C:C+8]});
      $else:
        const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
        $for C in range(0, CHANNEL_TILE, 8):
          vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale);

      // Upper clamp is applied in float; the lower clamp happens after packing
      // (saturating packs + max_epi8 below handle the low side).
      const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
      $for C in range(0, CHANNEL_TILE, 8):
        vfpacc${ABC[C:C+8]} = _mm256_min_ps(vfpacc${ABC[C:C+8]}, voutput_max_less_zero_point);

      // cvtps_epi32 rounds to nearest-even under the default MXCSR mode.
      $for C in range(0, CHANNEL_TILE, 8):
        vacc${ABC[C:C+8]} = _mm256_cvtps_epi32(vfpacc${ABC[C:C+8]});

      // Pack int32 -> int16 (per 128-bit lane, hence the interleaved name),
      // add the zero point with saturation, then pack to int8 and fix the
      // lane interleave with a 32-bit shuffle.
      const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point);
      $for C in range(0, CHANNEL_TILE, 16):
        const __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}), voutput_zero_point);

      $for C in range(0, CHANNEL_TILE, 16):
        __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0));

      const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
      $for C in range(0, CHANNEL_TILE, 16):
        vout${ABC[C:C+16]} = _mm_max_epi8(vout${ABC[C:C+16]}, voutput_min);

      _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
      $for C in range(16, CHANNEL_TILE, 16):
        _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
      output += ${CHANNEL_TILE};
    }
    // Remainder: fewer than CHANNEL_TILE channels left.  Processed 16 at a
    // time (XNN_OOB_READS permits reading past the valid channels; the
    // weight buffer is padded to a full tile).  Note this path always uses
    // the per-tap widening scheme, even in the _add16 variant.
    if XNN_UNLIKELY(c != 0) {
      $if CHANNEL_TILE > 16:
        // `k` walks the int8 taps independently of `w`, which below is only
        // advanced past the biases 16 at a time.
        const int8_t* k = (const int8_t*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
      ${"do " if CHANNEL_TILE > 16 else ""}{
        __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
        __m256i vacc${ABC[8:16]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + 8 * sizeof(int32_t)));

        // Same lane re-pairing trick as the main loop, for one 16-channel group.
        __m256i vacc${ABC[0:4]}${ABC[8:12]} = _mm256_inserti128_si256(vacc${ABC[0:8]}, _mm256_castsi256_si128(vacc${ABC[8:16]}), 1);
        __m256i vacc${ABC[4:8]}${ABC[12:16]} = _mm256_permute2x128_si256(vacc${ABC[0:8]}, vacc${ABC[8:16]}, 0x31);

        $for K in range(KERNEL_TILE):

          const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
          $if CHANNEL_TILE > 16:
            $if K == 0:
              const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) k));
            $else:
              const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE})));
          $else:
            const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(int8_t))));
          $if CHANNEL_TILE > 16:
            // Only needed when the remainder loop can run more than once.
            i${K} += 16;

          const __m256i vprod${K}x${ABC[0:16]}lo = _mm256_mullo_epi16(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
          const __m256i vprod${K}x${ABC[0:16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[0:16]}lo, 15);

          vacc${ABC[0:4]}${ABC[8:12]} = _mm256_add_epi32(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi));
          vacc${ABC[4:8]}${ABC[12:16]} = _mm256_add_epi32(vacc${ABC[4:8]}${ABC[12:16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi));

        vacc${ABC[0:8]} = _mm256_inserti128_si256(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_castsi256_si128(vacc${ABC[4:8]}${ABC[12:16]}), 1);
        vacc${ABC[8:16]} = _mm256_permute2x128_si256(vacc${ABC[0:4]}${ABC[8:12]}, vacc${ABC[4:8]}${ABC[12:16]}, 0x31);

        $if CHANNEL_TILE > 16:
          k += 16;

        __m256 vfpacc${ABC[0:8]} = _mm256_cvtepi32_ps(vacc${ABC[0:8]});
        __m256 vfpacc${ABC[8:16]} = _mm256_cvtepi32_ps(vacc${ABC[8:16]});

        $if DATATYPE == "QC8":
          // Scale address is a fixed offset from the CURRENT w.  This stays
          // correct across do-while iterations: w advances 16*sizeof(int32_t)
          // per step, which equals the 16 floats the scales advance by.
          const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t)));
          const __m256 vscale${ABC[8:16]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t) + 8 * sizeof(float)));
          vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale${ABC[0:8]});
          vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale${ABC[8:16]});
        $else:
          const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
          vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale);
          vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale);

        const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
        vfpacc${ABC[0:8]} = _mm256_min_ps(vfpacc${ABC[0:8]}, voutput_max_less_zero_point);
        vfpacc${ABC[8:16]} = _mm256_min_ps(vfpacc${ABC[8:16]}, voutput_max_less_zero_point);

        vacc${ABC[0:8]} = _mm256_cvtps_epi32(vfpacc${ABC[0:8]});
        vacc${ABC[8:16]} = _mm256_cvtps_epi32(vfpacc${ABC[8:16]});

        $if CHANNEL_TILE > 16:
          // Skip only the 16 consumed biases; taps are tracked by `k`.
          w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));

        const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
        __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point);
        __m128i vout${ABC[8:16]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[8:16]}), _mm256_extracti128_si256(vacc${ABC[8:16]}, 1)), voutput_zero_point);

        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);

        __m128i vout${ABC[0:16]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[8:16]});
        vout${ABC[0:16]} = _mm_max_epi8(vout${ABC[0:16]}, voutput_min);

        $if CHANNEL_TILE > 16:
          if XNN_LIKELY(c >= 16) {
            // Full 16-channel group; loop again if channels remain.
            _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
            output += 16;
            c -= 16;
          } else {
            // Final partial group: store 8/4/2/1 bytes, shifting the vector
            // down as each piece is written out.
            if (c & 8) {
              _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
              vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
              output += 8;
            }
            if (c & 4) {
              unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]}));
              vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
              output += 4;
            }
            if (c & 2) {
              unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0));
              vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
              output += 2;
            }
            if (c & 1) {
              *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
              output += 1;
            }
            c = 0;
          }
        $else:
          // CHANNEL_TILE == 16: the remainder is always partial (c < 16).
          if (c & 8) {
            _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
            vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
            output += 8;
          }
          if (c & 4) {
            unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]}));
            vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
            output += 4;
          }
          if (c & 2) {
            unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0));
            vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
            output += 2;
          }
          if (c & 1) {
            *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
            output += 1;
          }
      }${" while (c != 0);" if CHANNEL_TILE > 16 else ""}
    }

    output = (int8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}