1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert MR % 4 == 0 7$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 8#include <assert.h> 9 10#include <wasm_simd128.h> 11 12#include <xnnpack/spmm.h> 13 14$ARCH_SUFFIX = "_x86" if X86 else "_arm" 15 16void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}${"_x" + str(UNROLL) if UNROLL > 1 else ""}( 17 size_t mc, 18 size_t nc, 19 const float*restrict input, 20 const float*restrict weights, 21 const int32_t*restrict widx_dmap, 22 const uint32_t*restrict nidx_nnzmap, 23 float*restrict output, 24 size_t output_stride, 25 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) 26{ 27 assert(mc != 0); 28 assert(mc % sizeof(float) == 0); 29 assert(nc != 0); 30 31 const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min); 32 const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max); 33 size_t output_decrement = output_stride * nc - ${MR} * sizeof(float); 34 while XNN_LIKELY(mc >= ${MR} * sizeof(float)) { 35 const float*restrict w = weights; 36 const int32_t* dmap = widx_dmap; 37 const uint32_t* nnzmap = nidx_nnzmap; 38 size_t n = nc; 39 do { 40 uint32_t nnz = *nnzmap++; 41 $if UNROLL > 1: 42 v128_t vacc0123x0 = wasm_v128_load32_splat(w); 43 w += 1; 44 $for K in range(1, UNROLL): 45 v128_t vacc0123x${K} = wasm_f32x4_const_splat(0.0f); 46 $for M in range(4, MR, 4): 47 v128_t vacc${ABC[M:M+4]}x0 = vacc0123x0; 48 $for K in range(1, UNROLL): 49 v128_t vacc${ABC[M:M+4]}x${K} = wasm_f32x4_const_splat(0.0f); 50 for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) { 51 $for K in range(UNROLL): 52 const intptr_t diff${K} = dmap[${K}]; 53 dmap += ${UNROLL}; 54 $for K in range(UNROLL): 55 const v128_t vi0123x${K} = wasm_v128_load(input); 56 $for M in range(4, MR, 4): 57 const v128_t vi${ABC[M:M+4]}x${K} = wasm_v128_load(input + ${M}); 58 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K}); 59 const v128_t vw${K} = wasm_v128_load32_splat(w); 60 w += 1; 61 $for M in range(0, MR, 4): 62 vacc${ABC[M:M+4]}x${K} = wasm_f32x4_add(vacc${ABC[M:M+4]}x${K}, wasm_f32x4_mul(vi${ABC[M:M+4]}x${K}, vw${K})); 63 } 64 $for M in range(0, MR, 4): 65 v128_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0; 66 $for K in range(1, UNROLL): 67 $for M in range(0, MR, 4): 68 vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K}); 69 $else: 70 v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1; 71 $for M in range(4, MR, 4): 72 v128_t vacc${ABC[M:M+4]} = vacc0123; 73 if XNN_LIKELY(nnz != 0) { 74 do { 75 const intptr_t diff = *dmap++; 76 const v128_t vi0123 = wasm_v128_load(input); 77 $for M in range(4, MR, 4): 78 const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M}); 79 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff); 80 const v128_t vw = wasm_v128_load32_splat(w); w += 1; 81 $for M in range(0, MR, 4): 82 vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw)); 83 } while (--nnz != 0); 84 } 85 $if X86: 86 $for M in range(0, MR, 4): 87 v128_t vout${ABC[M:M+4]} = wasm_f32x4_pmin(vmax, vacc${ABC[M:M+4]}); 88 $for M in range(0, MR, 4): 89 vout${ABC[M:M+4]} = wasm_f32x4_pmax(vmin, vout${ABC[M:M+4]}); 90 $else: 91 $for M in range(0, MR, 4): 92 v128_t vout${ABC[M:M+4]} = wasm_f32x4_min(vacc${ABC[M:M+4]}, vmax); 93 $for M in range(0, MR, 4): 94 vout${ABC[M:M+4]} = wasm_f32x4_max(vout${ABC[M:M+4]}, vmin); 95 wasm_v128_store(output, vout0123); 96 $for M in range(4, MR, 4): 97 wasm_v128_store(output + ${M}, vout${ABC[M:M+4]}); 98 output = (float*restrict) ((uintptr_t) output + output_stride); 99 } while (--n != 0); 100 output = (float*restrict) ((uintptr_t) output - output_decrement); 101 input += ${MR}; 102 mc -= ${MR} * sizeof(float); 103 } 104 if XNN_UNLIKELY(mc != 0) { 105 $for LOG2M in reversed(range((MR - 1).bit_length())): 106 $SUBMR = 1 << LOG2M 107 $if SUBMR * 2 >= MR: 108 output_decrement += ${MR - SUBMR} * sizeof(float); 109 $else: 110 output_decrement += ${SUBMR} * sizeof(float); 111 if (mc & (${SUBMR} * sizeof(float))) { 112 const float*restrict w = weights; 113 const int32_t* dmap = widx_dmap; 114 const uint32_t* nnzmap = nidx_nnzmap; 115 size_t n = nc; 116 do { 117 uint32_t nnz = *nnzmap++; 118 $if SUBMR == 1: 119 v128_t vacc0 = wasm_v128_load32_splat(w); w += 1; 120 $elif SUBMR == 2: 121 v128_t vacc01 = wasm_v128_load32_splat(w); w += 1; 122 $else: 123 v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1; 124 $for M in range(4, SUBMR, 4): 125 v128_t vacc${ABC[M:M+4]} = vacc0123; 126 if XNN_LIKELY(nnz != 0) { 127 do { 128 const intptr_t diff = *dmap++; 129 $if SUBMR >= 4: 130 const v128_t vi0123 = wasm_v128_load(input); 131 $elif SUBMR == 2: 132 const v128_t vi01 = wasm_v128_load64_splat(input); 133 $elif SUBMR == 1: 134 const v128_t vi0 = wasm_v128_load32_splat(input); 135 $for M in range(4, SUBMR, 4): 136 const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M}); 137 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff); 138 const v128_t vw = wasm_v128_load32_splat(w); w += 1; 139 $if SUBMR == 1: 140 vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw)); 141 $else: 142 $for M in range(0, SUBMR, 4): 143 vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw)); 144 } while (--nnz != 0); 145 } 146 $if SUBMR == 1: 147 $if X86: 148 v128_t vout${ABC[0]} = wasm_f32x4_pmin(vmax, vacc${ABC[0]}); 149 vout${ABC[0]} = wasm_f32x4_pmax(vmin, vout${ABC[0]}); 150 $else: 151 v128_t vout${ABC[0]} = wasm_f32x4_min(vacc${ABC[0]}, vmax); 152 vout${ABC[0]} = wasm_f32x4_max(vout${ABC[0]}, vmin); 153 $else: 154 $if X86: 155 $for M in range(0, SUBMR, 4): 156 v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_pmin(vmax, vacc${ABC[M:min(M+4,SUBMR)]}); 157 $for M in range(0, SUBMR, 4): 158 vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_pmax(vmin, vout${ABC[M:min(M+4,SUBMR)]}); 159 $else: 160 $for M in range(0, SUBMR, 4): 161 v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_min(vacc${ABC[M:min(M+4,SUBMR)]}, vmax); 162 $for M in range(0, SUBMR, 4): 163 vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_max(vout${ABC[M:min(M+4,SUBMR)]}, vmin); 164 $if SUBMR >= 4: 165 wasm_v128_store(output, vout0123); 166 $elif SUBMR == 2: 167 *((double*) output) = wasm_f64x2_extract_lane(vout01, 0); 168 $elif SUBMR == 1: 169 *output = wasm_f32x4_extract_lane(vout0, 0); 170 171 $for M in range(4, SUBMR, 4): 172 wasm_v128_store(output + ${M}, vout${ABC[M:M+4]}); 173 output = (float*restrict) ((uintptr_t) output + output_stride); 174 } while (--n != 0); 175 output = (float*restrict) ((uintptr_t) output - output_decrement); 176 input += ${SUBMR}; 177 } 178 } 179} 180