1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert DATATYPE in ["QS8", "QU8"] 7$assert CHANNEL_TILE % 8 == 0 8$assert CHANNEL_TILE >= 8 9$assert ROW_TILE >= 3 10$assert REQUANTIZATION == "FP32" 11$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 12#include <assert.h> 13 14#include <wasm_simd128.h> 15 16#include <xnnpack/gavgpool.h> 17 18 19$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE] 20$WASM_X16X8_LOAD8X8 = {"QS8": "wasm_i16x8_load8x8", "QU8": "wasm_u16x8_load8x8"}[DATATYPE] 21$WASM_X32X4_EXTEND_LOW_X16X8 = {"QS8": "wasm_i32x4_extend_low_i16x8", "QU8": "wasm_u32x4_extend_low_u16x8"}[DATATYPE] 22$WASM_X32X4_EXTEND_HIGH_X16X8 = {"QS8": "wasm_i32x4_extend_high_i16x8", "QU8": "wasm_u32x4_extend_high_u16x8"}[DATATYPE] 23$WASM_X8X16_NARROW_I16X8 = {"QS8": "wasm_i8x16_narrow_i16x8", "QU8": "wasm_u8x16_narrow_i16x8"}[DATATYPE] 24$WASM_X8X16_MIN = {"QS8": "wasm_i8x16_min", "QU8": "wasm_u8x16_min"}[DATATYPE] 25void xnn_${DATATYPE.lower()}_gavgpool_minmax_fp32_ukernel_${ROW_TILE}x__wasmsimd_c${CHANNEL_TILE}( 26 size_t rows, 27 size_t channels, 28 const ${XINT8_T}* input, 29 size_t input_stride, 30 const ${XINT8_T}* zero, 31 ${XINT8_T}* output, 32 const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 33{ 34 assert(rows != 0); 35 assert(rows <= ${ROW_TILE}); 36 assert(channels != 0); 37 38 const ${XINT8_T}* i0 = input; 39 $for M in range(1, ROW_TILE): 40 const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride); 41 $if M % 2 == 1: 42 if XNN_UNPREDICTABLE(rows < ${M+1}) { 43 i${M} = zero; 44 } 45 $else: 46 if XNN_UNPREDICTABLE(rows <= ${M}) { 47 i${M} = zero; 48 } 49 50 const v128_t vinit_bias = wasm_v128_load64_splat(params->fp32_wasmsimd.init_bias); 51 const v128_t vscale = wasm_v128_load64_splat(params->fp32_wasmsimd.scale); 52 const v128_t vmagic_bias = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias); 53 const v128_t vmagic_min = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_min); 54 const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias_less_output_zero_point); 55 const v128_t voutput_max = wasm_v128_load64_splat(params->fp32_wasmsimd.output_max); 56 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { 57 $for M in range(2): 58 const v128_t vxi${M}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i${M}); 59 $for C in range(8, CHANNEL_TILE, 8): 60 const v128_t vxi${M}x${ABC[C:C+8]} = ${WASM_X16X8_LOAD8X8}(i${M} + ${C}); 61 i${M} += ${CHANNEL_TILE}; 62 63 v128_t vacc${ABC[0:8]} = wasm_i16x8_add(vxi0x${ABC[0:8]}, vxi1x${ABC[0:8]}); 64 const v128_t vxi2x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i2); 65 $for C in range(8, CHANNEL_TILE, 8): 66 v128_t vacc${ABC[C:C+8]} = wasm_i16x8_add(vxi0x${ABC[C:C+8]}, vxi1x${ABC[C:C+8]}); 67 const v128_t vxi2x${ABC[C:C+8]} = ${WASM_X16X8_LOAD8X8}(i2 + ${C}); 68 i2 += ${CHANNEL_TILE}; 69 70 $for M in range(3, ROW_TILE): 71 vacc${ABC[0:8]} = wasm_i16x8_add(vacc${ABC[0:8]}, vxi${M-1}x${ABC[0:8]}); 72 const v128_t vxi${M}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i${M}); 73 $for C in range(8, CHANNEL_TILE, 8): 74 vacc${ABC[C:C+8]} = wasm_i16x8_add(vacc${ABC[C:C+8]}, vxi${M-1}x${ABC[C:C+8]}); 75 const v128_t vxi${M}x${ABC[C:C+8]} = ${WASM_X16X8_LOAD8X8}(i${M} + ${C}); 76 i${M} += ${CHANNEL_TILE}; 77 78 $for C in range(0, CHANNEL_TILE, 8): 79 vacc${ABC[C:C+8]} = wasm_i16x8_add(vacc${ABC[C:C+8]}, vxi${ROW_TILE-1}x${ABC[C:C+8]}); 80 81 $for C in range(0, CHANNEL_TILE, 8): 82 v128_t vacc${ABC[C:C+4]} = wasm_i32x4_add(vinit_bias, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc${ABC[C:C+8]})); 83 v128_t vacc${ABC[C+4:C+8]} = wasm_i32x4_add(vinit_bias, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc${ABC[C:C+8]})); 84 85 $for C in range(0, CHANNEL_TILE, 4): 86 vacc${ABC[C:C+4]} = wasm_f32x4_convert_i32x4(vacc${ABC[C:C+4]}); 87 88 $for C in range(0, CHANNEL_TILE, 4): 89 vacc${ABC[C:C+4]} = wasm_f32x4_mul(vacc${ABC[C:C+4]}, vscale); 90 91 $for C in range(0, CHANNEL_TILE, 4): 92 vacc${ABC[C:C+4]} = wasm_f32x4_add(vacc${ABC[C:C+4]}, vmagic_bias); 93 94 $for C in range(0, CHANNEL_TILE, 4): 95 vacc${ABC[C:C+4]} = wasm_i32x4_max(vacc${ABC[C:C+4]}, vmagic_min); 96 97 $for C in range(0, CHANNEL_TILE, 4): 98 vacc${ABC[C:C+4]} = wasm_i32x4_sub(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point); 99 100 $for C in range(0, CHANNEL_TILE, 8): 101 v128_t vout${ABC[C:C+8]} = wasm_i16x8_narrow_i32x4(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]}); 102 103 $for C in range(0, CHANNEL_TILE, 16): 104 $if C + 8 < CHANNEL_TILE: 105 v128_t vout${ABC[C:C+16]} = ${WASM_X8X16_NARROW_I16X8}(vout${ABC[C:C+8]}, vout${ABC[C+8:C+16]}); 106 $else: 107 v128_t vout${ABC[C:C+8]}${ABC[C:C+8]} = ${WASM_X8X16_NARROW_I16X8}(vout${ABC[C:C+8]}, vout${ABC[C:C+8]}); 108 109 $for C in range(0, CHANNEL_TILE, 16): 110 $if C + 8 < CHANNEL_TILE: 111 vout${ABC[C:C+16]} = ${WASM_X8X16_MIN}(vout${ABC[C:C+16]}, voutput_max); 112 $else: 113 vout${ABC[C:C+8]}${ABC[C:C+8]} = ${WASM_X8X16_MIN}(vout${ABC[C:C+8]}${ABC[C:C+8]}, voutput_max); 114 115 $if CHANNEL_TILE > 8: 116 wasm_v128_store(output, vout${ABC[0:16]}); 117 $else: 118 *((double*) output) = wasm_f64x2_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0); 119 $for C in range(16, CHANNEL_TILE, 16): 120 $if C + 8 < CHANNEL_TILE: 121 wasm_v128_store(output + ${C}, vout${ABC[C:C+16]}); 122 $else: 123 *((double*) (output + ${C})) = wasm_f64x2_extract_lane(vout${ABC[C:C+8]}${ABC[C:C+8]}, 0); 124 output += ${CHANNEL_TILE}; 125 } 126 if XNN_UNLIKELY(channels != 0) { 127 ${"do " if CHANNEL_TILE > 8 else ""}{ 128 $for M in range(2): 129 const v128_t vxi${M}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i${M}); 130 i${M} += 8; 131 132 v128_t vacc${ABC[0:8]} = wasm_i16x8_add(vxi0x${ABC[0:8]}, vxi1x${ABC[0:8]}); 133 const v128_t vxi2x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i2); 134 i2 += 8; 135 136 $for M in range(3, ROW_TILE): 137 vacc${ABC[0:8]} = wasm_i16x8_add(vacc${ABC[0:8]}, vxi${M-1}x${ABC[0:8]}); 138 const v128_t vxi${M}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i${M}); 139 i${M} += 8; 140 141 vacc${ABC[0:8]} = wasm_i16x8_add(vacc${ABC[0:8]}, vxi${ROW_TILE-1}x${ABC[0:8]}); 142 143 v128_t vacc${ABC[0:4]} = wasm_i32x4_add(vinit_bias, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc${ABC[0:8]})); 144 v128_t vacc${ABC[4:8]} = wasm_i32x4_add(vinit_bias, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc${ABC[0:8]})); 145 146 vacc${ABC[0:4]} = wasm_f32x4_convert_i32x4(vacc${ABC[0:4]}); 147 vacc${ABC[4:8]} = wasm_f32x4_convert_i32x4(vacc${ABC[4:8]}); 148 149 vacc${ABC[0:4]} = wasm_f32x4_mul(vacc${ABC[0:4]}, vscale); 150 vacc${ABC[4:8]} = wasm_f32x4_mul(vacc${ABC[4:8]}, vscale); 151 152 vacc${ABC[0:4]} = wasm_f32x4_add(vacc${ABC[0:4]}, vmagic_bias); 153 vacc${ABC[4:8]} = wasm_f32x4_add(vacc${ABC[4:8]}, vmagic_bias); 154 155 vacc${ABC[0:4]} = wasm_i32x4_max(vacc${ABC[0:4]}, vmagic_min); 156 vacc${ABC[4:8]} = wasm_i32x4_max(vacc${ABC[4:8]}, vmagic_min); 157 158 vacc${ABC[0:4]} = wasm_i32x4_sub(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point); 159 vacc${ABC[4:8]} = wasm_i32x4_sub(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point); 160 161 const v128_t vout${ABC[0:8]} = wasm_i16x8_narrow_i32x4(vacc${ABC[0:4]}, vacc${ABC[4:8]}); 162 v128_t vout${ABC[0:8]}${ABC[0:8]} = ${WASM_X8X16_NARROW_I16X8}(vout${ABC[0:8]}, vout${ABC[0:8]}); 163 vout${ABC[0:8]}${ABC[0:8]} = ${WASM_X8X16_MIN}(vout${ABC[0:8]}${ABC[0:8]}, voutput_max); 164 165 $if CHANNEL_TILE > 8: 166 if XNN_LIKELY(channels >= 8) { 167 *((double*) output) = wasm_f64x2_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0); 168 output += 8; 169 channels -= 8; 170 } else { 171 if (channels & 4) { 172 *((float*) output) = wasm_f32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0); 173 vout${ABC[0:8]}${ABC[0:8]} = wasm_u64x2_shr(vout${ABC[0:8]}${ABC[0:8]}, 32); 174 output += 4; 175 } 176 uint32_t vout${ABC[0:4]} = wasm_i32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0); 177 if (channels & 2) { 178 *((uint16_t*) output) = (uint16_t) vout${ABC[0:4]}; 179 vout${ABC[0:4]} >>= 16; 180 output += 2; 181 } 182 if (channels & 1) { 183 *output = (${XINT8_T}) vout${ABC[0:4]}; 184 output += 1; 185 } 186 channels = 0; 187 } 188 $else: 189 if (channels & 4) { 190 *((float*) output) = wasm_f32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0); 191 vout${ABC[0:8]}${ABC[0:8]} = wasm_u64x2_shr(vout${ABC[0:8]}${ABC[0:8]}, 32); 192 output += 4; 193 } 194 uint32_t vout${ABC[0:4]} = wasm_i32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0); 195 if (channels & 2) { 196 *((uint16_t*) output) = (uint16_t) vout${ABC[0:4]}; 197 vout${ABC[0:4]} >>= 16; 198 output += 2; 199 } 200 if (channels & 1) { 201 *output = (${XINT8_T}) vout${ABC[0:4]}; 202 } 203 }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""} 204 } 205} 206