1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 7$assert REQUANTIZATION == "FP32" 8$assert DATATYPE in ["QC8", "QS8"] 9$assert CHANNEL_TILE % 16 == 0 10$assert CHANNEL_TILE >= 16 11$assert KERNEL_TILE >= 2 12#include <assert.h> 13 14#include <immintrin.h> 15 16#include <xnnpack/dwconv.h> 17 18 19$PARAMS_STRUCT = "avx2" if DATATYPE == "QC8" else REQUANTIZATION.lower() + "_avx2" 20$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_qs8_conv_minmax_params" 21void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx2_mul16${"_add16" if ADD16 else ""}_vpunpck( 22 size_t channels, 23 size_t output_width, 24 const int8_t** input, 25 const void* weights, 26 int8_t* output, 27 size_t input_stride, 28 size_t output_increment, 29 size_t input_offset, 30 const int8_t* zero, 31 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 32{ 33 assert(channels != 0); 34 assert(output_width != 0); 35 36 do { 37 $for K in range(KERNEL_TILE): 38 const int8_t* i${K} = input[${K}]; 39 assert(i${K} != NULL); 40 if XNN_UNPREDICTABLE(i${K} != zero) { 41 i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset); 42 } 43 input = (const int8_t**) ((uintptr_t) input + input_stride); 44 45 size_t c = channels; 46 const void* w = weights; 47 for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) { 48 __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w); 49 $for C in range(8, CHANNEL_TILE, 8): 50 __m256i vacc${ABC[C:C+8]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + ${C} * sizeof(int32_t))); 51 52 $for C in range(0, CHANNEL_TILE, 16): 53 __m256i vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_inserti128_si256(vacc${ABC[C:C+8]}, _mm256_castsi256_si128(vacc${ABC[C+8:C+16]}), 1); 54 __m256i vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}, 0x31); 55 56 $for K in range(KERNEL_TILE): 57 58 $for C in range(0, CHANNEL_TILE, 16): 59 $if C == 0: 60 const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K})); 61 $else: 62 const __m256i vi${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (i${K} + ${C}))); 63 const __m256i vk${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(int8_t)))); 64 i${K} += ${CHANNEL_TILE}; 65 66 $if ADD16: 67 $for C in range(0, CHANNEL_TILE, 16): 68 $if K == 0: 69 __m256i vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}); 70 $elif K % 2 == 0 or K + 1 == KERNEL_TILE: 71 vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}); 72 $else: 73 vacc${ABC[C:C+16]} = _mm256_add_epi16(vacc${ABC[C:C+16]}, _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]})); 74 75 $if K % 2 == 1 or K + 1 == KERNEL_TILE: 76 $for C in range(0, CHANNEL_TILE, 16): 77 $if K == 1: 78 __m256i vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15); 79 $else: 80 vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15); 81 vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]})); 82 vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]})); 83 $else: 84 $for C in range(0, CHANNEL_TILE, 16): 85 const __m256i vprod${K}x${ABC[C:C+16]}lo = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}); 86 const __m256i vprod${K}x${ABC[C:C+16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[C:C+16]}lo, 15); 87 88 $for C in range(0, CHANNEL_TILE, 16): 89 vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi)); 90 vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi)); 91 92 w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(int8_t)); 93 94 $for C in range(0, CHANNEL_TILE, 16): 95 vacc${ABC[C:C+8]} = _mm256_inserti128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_castsi256_si128(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}), 1); 96 vacc${ABC[C+8:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 0x31); 97 98 $for C in range(0, CHANNEL_TILE, 8): 99 __m256 vfpacc${ABC[C:C+8]} = _mm256_cvtepi32_ps(vacc${ABC[C:C+8]}); 100 101 $if DATATYPE == "QC8": 102 const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) w); 103 $for C in range(8, CHANNEL_TILE, 8): 104 const __m256 vscale${ABC[C:C+8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${C} * sizeof(float))); 105 w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(float)); 106 $for C in range(0, CHANNEL_TILE, 8): 107 vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale${ABC[C:C+8]}); 108 $else: 109 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale); 110 $for C in range(0, CHANNEL_TILE, 8): 111 vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale); 112 113 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); 114 $for C in range(0, CHANNEL_TILE, 8): 115 vfpacc${ABC[C:C+8]} = _mm256_min_ps(vfpacc${ABC[C:C+8]}, voutput_max_less_zero_point); 116 117 $for C in range(0, CHANNEL_TILE, 8): 118 vacc${ABC[C:C+8]} = _mm256_cvtps_epi32(vfpacc${ABC[C:C+8]}); 119 120 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point); 121 $for C in range(0, CHANNEL_TILE, 16): 122 const __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}), voutput_zero_point); 123 124 $for C in range(0, CHANNEL_TILE, 16): 125 __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0)); 126 127 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); 128 $for C in range(0, CHANNEL_TILE, 16): 129 vout${ABC[C:C+16]} = _mm_max_epi8(vout${ABC[C:C+16]}, voutput_min); 130 131 _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]}); 132 $for C in range(16, CHANNEL_TILE, 16): 133 _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]}); 134 output += ${CHANNEL_TILE}; 135 } 136 if XNN_UNLIKELY(c != 0) { 137 $if CHANNEL_TILE > 16: 138 const int8_t* k = (const int8_t*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)); 139 ${"do " if CHANNEL_TILE > 16 else ""}{ 140 __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w); 141 __m256i vacc${ABC[8:16]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + 8 * sizeof(int32_t))); 142 143 __m256i vacc${ABC[0:4]}${ABC[8:12]} = _mm256_inserti128_si256(vacc${ABC[0:8]}, _mm256_castsi256_si128(vacc${ABC[8:16]}), 1); 144 __m256i vacc${ABC[4:8]}${ABC[12:16]} = _mm256_permute2x128_si256(vacc${ABC[0:8]}, vacc${ABC[8:16]}, 0x31); 145 146 $for K in range(KERNEL_TILE): 147 148 const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K})); 149 $if CHANNEL_TILE > 16: 150 $if K == 0: 151 const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) k)); 152 $else: 153 const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE}))); 154 $else: 155 const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(int8_t)))); 156 $if CHANNEL_TILE > 16: 157 i${K} += 16; 158 159 const __m256i vprod${K}x${ABC[0:16]}lo = _mm256_mullo_epi16(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]}); 160 const __m256i vprod${K}x${ABC[0:16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[0:16]}lo, 15); 161 162 vacc${ABC[0:4]}${ABC[8:12]} = _mm256_add_epi32(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi)); 163 vacc${ABC[4:8]}${ABC[12:16]} = _mm256_add_epi32(vacc${ABC[4:8]}${ABC[12:16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi)); 164 165 vacc${ABC[0:8]} = _mm256_inserti128_si256(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_castsi256_si128(vacc${ABC[4:8]}${ABC[12:16]}), 1); 166 vacc${ABC[8:16]} = _mm256_permute2x128_si256(vacc${ABC[0:4]}${ABC[8:12]}, vacc${ABC[4:8]}${ABC[12:16]}, 0x31); 167 168 $if CHANNEL_TILE > 16: 169 k += 16; 170 171 __m256 vfpacc${ABC[0:8]} = _mm256_cvtepi32_ps(vacc${ABC[0:8]}); 172 __m256 vfpacc${ABC[8:16]} = _mm256_cvtepi32_ps(vacc${ABC[8:16]}); 173 174 $if DATATYPE == "QC8": 175 const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t))); 176 const __m256 vscale${ABC[8:16]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t) + 8 * sizeof(float))); 177 vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale${ABC[0:8]}); 178 vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale${ABC[8:16]}); 179 $else: 180 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale); 181 vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale); 182 vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale); 183 184 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); 185 vfpacc${ABC[0:8]} = _mm256_min_ps(vfpacc${ABC[0:8]}, voutput_max_less_zero_point); 186 vfpacc${ABC[8:16]} = _mm256_min_ps(vfpacc${ABC[8:16]}, voutput_max_less_zero_point); 187 188 vacc${ABC[0:8]} = _mm256_cvtps_epi32(vfpacc${ABC[0:8]}); 189 vacc${ABC[8:16]} = _mm256_cvtps_epi32(vfpacc${ABC[8:16]}); 190 191 $if CHANNEL_TILE > 16: 192 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t)); 193 194 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point); 195 __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point); 196 __m128i vout${ABC[8:16]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[8:16]}), _mm256_extracti128_si256(vacc${ABC[8:16]}, 1)), voutput_zero_point); 197 198 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); 199 200 __m128i vout${ABC[0:16]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[8:16]}); 201 vout${ABC[0:16]} = _mm_max_epi8(vout${ABC[0:16]}, voutput_min); 202 203 $if CHANNEL_TILE > 16: 204 if XNN_LIKELY(c >= 16) { 205 _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]}); 206 output += 16; 207 c -= 16; 208 } else { 209 if (c & 8) { 210 _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]}); 211 vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]}); 212 output += 8; 213 } 214 if (c & 4) { 215 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]}); 216 vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32); 217 output += 4; 218 } 219 if (c & 2) { 220 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0); 221 vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16); 222 output += 2; 223 } 224 if (c & 1) { 225 *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0); 226 output += 1; 227 } 228 c = 0; 229 } 230 $else: 231 if (c & 8) { 232 _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]}); 233 vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]}); 234 output += 8; 235 } 236 if (c & 4) { 237 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]}); 238 vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32); 239 output += 4; 240 } 241 if (c & 2) { 242 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0); 243 vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16); 244 output += 2; 245 } 246 if (c & 1) { 247 *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0); 248 output += 1; 249 } 250 }${" while (c != 0);" if CHANNEL_TILE > 16 else ""} 251 } 252 253 output = (int8_t*) ((uintptr_t) output + output_increment); 254 } while (--output_width != 0); 255} 256