// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/spmm.h>

$ARCH_SUFFIX = "_x86" if X86 else "_arm"

void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
  const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
  $if UNROLL > 1:
    const v128_t vzero = wasm_f64x2_splat(0.0);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
  // Main loop: process the batch in blocks of ${MR} elements (mc is in bytes).
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;  // number of non-zero weights in this output channel
      $if UNROLL > 1:
        // Initialize the first accumulator with the channel bias; the extra
        // accumulators for the unrolled iterations start at zero.
        v128_t vacc0123x0 = wasm_v32x4_load_splat(w);
        w += 1;
        $for K in range(1, UNROLL):
          v128_t vacc0123x${K} = vzero;
        $for M in range(4, MR, 4):
          v128_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
          $for K in range(1, UNROLL):
            v128_t vacc${ABC[M:M+4]}x${K} = vzero;
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(UNROLL):
            const intptr_t diff${K} = dmap[${K}];
          dmap += ${UNROLL};
          $for K in range(UNROLL):
            const v128_t vi0123x${K} = wasm_v128_load(input);
            $for M in range(4, MR, 4):
              const v128_t vi${ABC[M:M+4]}x${K} = wasm_v128_load(input + ${M});
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
            const v128_t vw${K} = wasm_v32x4_load_splat(w);
            w += 1;
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]}x${K} = wasm_f32x4_add(vacc${ABC[M:M+4]}x${K}, wasm_f32x4_mul(vi${ABC[M:M+4]}x${K}, vw${K}));
        }
        // Reduce the unrolled partial sums into a single accumulator.
        $for M in range(0, MR, 4):
          v128_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
        $for K in range(1, UNROLL):
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
      $else:
        v128_t vacc0123 = wasm_v32x4_load_splat(w); w += 1;
        $for M in range(4, MR, 4):
          v128_t vacc${ABC[M:M+4]} = vacc0123;
      if XNN_LIKELY(nnz != 0) {
        do {
          // dmap holds delta-encoded byte offsets between consecutive
          // non-zero input elements, applied after the load below.
          const intptr_t diff = *dmap++;
          const v128_t vi0123 = wasm_v128_load(input);
          $for M in range(4, MR, 4):
            const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          const v128_t vw = wasm_v32x4_load_splat(w); w += 1;
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
        } while (--nnz != 0);
      }
      // Clamp to [min, max]. The x86 variant uses compare+bitselect, which
      // lowers better than f32x4.min/max on x86.
      $if X86:
        $for M in range(0, MR, 4):
          v128_t vout${ABC[M:M+4]} = wasm_v128_bitselect(vacc${ABC[M:M+4]}, vmax, wasm_f32x4_le(vacc${ABC[M:M+4]}, vmax));
        $for M in range(0, MR, 4):
          vout${ABC[M:M+4]} = wasm_v128_bitselect(vmin, vout${ABC[M:M+4]}, wasm_f32x4_lt(vout${ABC[M:M+4]}, vmin));
      $else:
        $for M in range(0, MR, 4):
          v128_t vout${ABC[M:M+4]} = wasm_f32x4_min(vacc${ABC[M:M+4]}, vmax);
        $for M in range(0, MR, 4):
          vout${ABC[M:M+4]} = wasm_f32x4_max(vout${ABC[M:M+4]}, vmin);
      wasm_v128_store(output, vout0123);
      $for M in range(4, MR, 4):
        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
  // Remainder: handle the batch tail in power-of-2 sub-blocks of ${MR}.
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR == 1:
            v128_t vacc0 = wasm_v32x4_load_splat(w); w += 1;
          $elif SUBMR == 2:
            v128_t vacc01 = wasm_v32x4_load_splat(w); w += 1;
          $else:
            v128_t vacc0123 = wasm_v32x4_load_splat(w); w += 1;
          $for M in range(4, SUBMR, 4):
            v128_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR >= 4:
                const v128_t vi0123 = wasm_v128_load(input);
              $elif SUBMR == 2:
                const v128_t vi01 = wasm_v64x2_load_splat(input);
              $elif SUBMR == 1:
                const v128_t vi0 = wasm_v32x4_load_splat(input);
              $for M in range(4, SUBMR, 4):
                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const v128_t vw = wasm_v32x4_load_splat(w); w += 1;
              $if SUBMR == 1:
                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
            } while (--nnz != 0);
          }
          $if SUBMR == 1:
            $if X86:
              v128_t vout${ABC[0]} = wasm_v128_bitselect(vacc${ABC[0]}, vmax, wasm_f32x4_le(vacc${ABC[0]}, vmax));
              vout${ABC[0]} = wasm_v128_bitselect(vmin, vout${ABC[0]}, wasm_f32x4_lt(vout${ABC[0]}, vmin));
            $else:
              v128_t vout${ABC[0]} = wasm_f32x4_min(vacc${ABC[0]}, vmax);
              vout${ABC[0]} = wasm_f32x4_max(vout${ABC[0]}, vmin);
          $else:
            $if X86:
              $for M in range(0, SUBMR, 4):
                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vacc${ABC[M:min(M+4,SUBMR)]}, vmax, wasm_f32x4_le(vacc${ABC[M:min(M+4,SUBMR)]}, vmax));
              $for M in range(0, SUBMR, 4):
                vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vmin, vout${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_lt(vout${ABC[M:min(M+4,SUBMR)]}, vmin));
            $else:
              $for M in range(0, SUBMR, 4):
                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_min(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
              $for M in range(0, SUBMR, 4):
                vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_max(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
          $if SUBMR >= 4:
            wasm_v128_store(output, vout0123);
          $elif SUBMR == 2:
            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
          $elif SUBMR == 1:
            *output = wasm_f32x4_extract_lane(vout0, 0);

          $for M in range(4, SUBMR, 4):
            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}
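
// Usage sketch (illustrative only, not part of the template output). It shows
// how a caller might drive one specialization of this template. The
// specialization name assumes MR=4, NR=1, UNROLL=1, and the ARM lowering;
// batch_size, output_channels, and the operand-layout notes below are
// inferred from the pointer arithmetic above, not from XNNPACK's packing
// routines:
//
//   - input is assumed to be laid out [input channel][batch element], with
//     mc giving the batch extent in bytes, so one v128 load picks up four
//     consecutive batch elements of a single input channel;
//   - weights holds, per output channel, one value that initializes the
//     accumulators (a bias) followed by that channel's non-zero weights;
//   - nidx_nnzmap[n] is the non-zero count of output channel n;
//   - widx_dmap holds one byte delta per non-zero, added to the input
//     pointer *after* the matching load, steering it to the next non-zero's
//     input channel; the deltas of one full pass over all channels must sum
//     to zero so the next batch block starts back at the block base.
//
//   union xnn_f32_minmax_params params;
//   params.scalar.min = 0.0f;  // e.g. ReLU6-style clamping
//   params.scalar.max = 6.0f;
//   xnn_f32_spmm_minmax_ukernel_4x1__wasmsimd_arm(
//       batch_size * sizeof(float),   // mc: batch extent, in bytes
//       output_channels,              // nc
//       input, weights, widx_dmap, nidx_nnzmap,
//       output,
//       batch_size * sizeof(float),   // output_stride, in bytes
//       &params);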