// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/spmm.h>

$ARCH_SUFFIX = "_x86" if X86 else "_arm"

// f32 sparse-matrix * dense-matrix (SpMM) micro-kernel with min/max clamping,
// WebAssembly SIMD, software-pipelined variant: the next weight splat and the
// next input vectors are loaded while the current multiply-accumulate issues.
//
//   mc            - count of "m" elements to process, in BYTES of input
//                   (asserted to be a multiple of sizeof(float))
//   nc            - number of output channels (columns); one nnzmap entry each
//   widx_dmap     - per-nonzero BYTE deltas applied to the input pointer
//   nidx_nnzmap   - per-output-channel count of non-zero weights
//   output_stride - byte stride between consecutive output channels
void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}_pipelined${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
  const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
  // After writing nc channels of one MR-wide tile, rewind output back to the
  // start of the next tile: nc strides forward minus the MR floats just written.
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    // Pipeline prologue: preload the first weight splat, the first input-pointer
    // delta, and the first MR input lanes before entering the channel loop.
    v128_t vw = wasm_v32x4_load_splat(w); w += 1;
    intptr_t diff = *dmap++;
    $for M in range(0, MR, 4):
      v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      // The preloaded vw is the bias for this channel; immediately preload the
      // first sparse weight so it is in flight during the accumulator setup.
      $for M in range(0, MR, 4):
        v128_t vacc${ABC[M:M+4]} = vw;
      vw = wasm_v32x4_load_splat(w); w += 1;

      $if UNROLL > 1:
        // Unrolled pipelined MAC loop: for each non-zero weight, accumulate with
        // the already-loaded vi/vw, then advance input by the byte delta and
        // load the operands for the next iteration.
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(0, UNROLL):
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            diff = *dmap++;
            vw = wasm_v32x4_load_splat(w); w += 1;
            $for M in range(0, MR, 4):
              vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
        }

      // Scalar (one non-zero per iteration) pipelined MAC loop.
      if XNN_LIKELY(nnz != 0) {
        do {
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);

          diff = *dmap++;
          vw = wasm_v32x4_load_splat(w); w += 1;
          $for M in range(0, MR, 4):
            vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
        } while (--nnz != 0);
      }
      // Clamp to [min, max]. The x86 variant uses compare + bitselect in place
      // of f32x4_min/f32x4_max; the arm variant uses the min/max instructions.
      $if X86:
        $for M in range(0, MR, 4):
          v128_t vout${ABC[M:M+4]} = wasm_v128_bitselect(vacc${ABC[M:M+4]}, vmax, wasm_f32x4_le(vacc${ABC[M:M+4]}, vmax));
        $for M in range(0, MR, 4):
          vout${ABC[M:M+4]} = wasm_v128_bitselect(vmin, vout${ABC[M:M+4]}, wasm_f32x4_lt(vout${ABC[M:M+4]}, vmin));
      $else:
        $for M in range(0, MR, 4):
          v128_t vout${ABC[M:M+4]} = wasm_f32x4_min(vacc${ABC[M:M+4]}, vmax);
        $for M in range(0, MR, 4):
          vout${ABC[M:M+4]} = wasm_f32x4_max(vout${ABC[M:M+4]}, vmin);
      wasm_v128_store(output, vout0123);
      $for M in range(4, MR, 4):
        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
  // Remainder path: process leftover rows (mc not a multiple of MR floats) in
  // descending power-of-two sub-tiles; this path is not pipelined.
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          // Accumulators start from the per-channel bias splat at *w.
          $if SUBMR == 1:
            v128_t vacc0 = wasm_v32x4_load_splat(w); w += 1;
          $elif SUBMR == 2:
            v128_t vacc01 = wasm_v32x4_load_splat(w); w += 1;
          $else:
            v128_t vacc0123 = wasm_v32x4_load_splat(w); w += 1;
          $for M in range(4, SUBMR, 4):
            v128_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR >= 4:
                const v128_t vi0123 = wasm_v128_load(input);
              $elif SUBMR == 2:
                const v128_t vi01 = wasm_v64x2_load_splat(input);
              $elif SUBMR == 1:
                const v128_t vi0 = wasm_v32x4_load_splat(input);
              $for M in range(4, SUBMR, 4):
                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const v128_t vw = wasm_v32x4_load_splat(w); w += 1;
              $if SUBMR == 1:
                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
            } while (--nnz != 0);
          }
          // Clamp, then store only the SUBMR valid lanes (full vector for >= 4,
          // low 64 bits for 2, low 32 bits for 1).
          $if SUBMR == 1:
            $if X86:
              v128_t vout${ABC[0]} = wasm_v128_bitselect(vacc${ABC[0]}, vmax, wasm_f32x4_le(vacc${ABC[0]}, vmax));
              vout${ABC[0]} = wasm_v128_bitselect(vmin, vout${ABC[0]}, wasm_f32x4_lt(vout${ABC[0]}, vmin));
            $else:
              v128_t vout${ABC[0]} = wasm_f32x4_min(vacc${ABC[0]}, vmax);
              vout${ABC[0]} = wasm_f32x4_max(vout${ABC[0]}, vmin);
          $else:
            $if X86:
              $for M in range(0, SUBMR, 4):
                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vacc${ABC[M:min(M+4,SUBMR)]}, vmax, wasm_f32x4_le(vacc${ABC[M:min(M+4,SUBMR)]}, vmax));
              $for M in range(0, SUBMR, 4):
                vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vmin, vout${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_lt(vout${ABC[M:min(M+4,SUBMR)]}, vmin));
            $else:
              $for M in range(0, SUBMR, 4):
                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_min(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
              $for M in range(0, SUBMR, 4):
                vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_max(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
          $if SUBMR >= 4:
            wasm_v128_store(output, vout0123);
          $elif SUBMR == 2:
            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
          $elif SUBMR == 1:
            *output = wasm_f32x4_extract_lane(vout0, 0);

          $for M in range(4, SUBMR, 4):
            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}