// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/spmm.h>

$ARCH_SUFFIX = "_x86" if X86 else "_arm"

16void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
17    size_t mc,
18    size_t nc,
19    const float*restrict input,
20    const float*restrict weights,
21    const int32_t*restrict widx_dmap,
22    const uint32_t*restrict nidx_nnzmap,
23    float*restrict output,
24    size_t output_stride,
25    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
26{
27  assert(mc != 0);
28  assert(mc % sizeof(float) == 0);
29  assert(nc != 0);
30
31  const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
32  const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
33  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
34  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
35    const float*restrict w = weights;
36    const int32_t* dmap = widx_dmap;
37    const uint32_t* nnzmap = nidx_nnzmap;
38    size_t n = nc;
39    do {
40      uint32_t nnz = *nnzmap++;
41      $if UNROLL > 1:
42        v128_t vacc0123x0 = wasm_v128_load32_splat(w);
43        w += 1;
44        $for K in range(1, UNROLL):
45          v128_t vacc0123x${K} = wasm_f32x4_const_splat(0.0f);
46        $for M in range(4, MR, 4):
47          v128_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
48          $for K in range(1, UNROLL):
49            v128_t vacc${ABC[M:M+4]}x${K} = wasm_f32x4_const_splat(0.0f);
50        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
51          $for K in range(UNROLL):
52            const intptr_t diff${K} = dmap[${K}];
53          dmap += ${UNROLL};
54          $for K in range(UNROLL):
55            const v128_t vi0123x${K} = wasm_v128_load(input);
56            $for M in range(4, MR, 4):
57              const v128_t vi${ABC[M:M+4]}x${K} = wasm_v128_load(input + ${M});
58            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
59            const v128_t vw${K} = wasm_v128_load32_splat(w);
60            w += 1;
61            $for M in range(0, MR, 4):
62              vacc${ABC[M:M+4]}x${K} = wasm_f32x4_add(vacc${ABC[M:M+4]}x${K}, wasm_f32x4_mul(vi${ABC[M:M+4]}x${K}, vw${K}));
63        }
64        $for M in range(0, MR, 4):
65          v128_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
66        $for K in range(1, UNROLL):
67          $for M in range(0, MR, 4):
68            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
69      $else:
70        v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
71        $for M in range(4, MR, 4):
72          v128_t vacc${ABC[M:M+4]} = vacc0123;
73      if XNN_LIKELY(nnz != 0) {
74        do {
75          const intptr_t diff = *dmap++;
76          const v128_t vi0123 = wasm_v128_load(input);
77          $for M in range(4, MR, 4):
78            const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
79          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
80          const v128_t vw = wasm_v128_load32_splat(w); w += 1;
81          $for M in range(0, MR, 4):
82            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
83        } while (--nnz != 0);
84      }
85      $if X86:
86        $for M in range(0, MR, 4):
87          v128_t vout${ABC[M:M+4]} = wasm_f32x4_pmin(vmax, vacc${ABC[M:M+4]});
88        $for M in range(0, MR, 4):
89          vout${ABC[M:M+4]} = wasm_f32x4_pmax(vmin, vout${ABC[M:M+4]});
90      $else:
91        $for M in range(0, MR, 4):
92          v128_t vout${ABC[M:M+4]} = wasm_f32x4_min(vacc${ABC[M:M+4]}, vmax);
93        $for M in range(0, MR, 4):
94          vout${ABC[M:M+4]} = wasm_f32x4_max(vout${ABC[M:M+4]}, vmin);
95      wasm_v128_store(output, vout0123);
96      $for M in range(4, MR, 4):
97        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
98      output = (float*restrict) ((uintptr_t) output + output_stride);
99    } while (--n != 0);
100    output = (float*restrict) ((uintptr_t) output - output_decrement);
101    input += ${MR};
102    mc -= ${MR} * sizeof(float);
103  }
104  if XNN_UNLIKELY(mc != 0) {
105    $for LOG2M in reversed(range((MR - 1).bit_length())):
106      $SUBMR = 1 << LOG2M
107      $if SUBMR * 2 >= MR:
108        output_decrement += ${MR - SUBMR} * sizeof(float);
109      $else:
110        output_decrement += ${SUBMR} * sizeof(float);
111      if (mc & (${SUBMR} * sizeof(float))) {
112        const float*restrict w = weights;
113        const int32_t* dmap = widx_dmap;
114        const uint32_t* nnzmap = nidx_nnzmap;
115        size_t n = nc;
116        do {
117          uint32_t nnz = *nnzmap++;
118          $if SUBMR == 1:
119            v128_t vacc0 = wasm_v128_load32_splat(w); w += 1;
120          $elif SUBMR == 2:
121            v128_t vacc01 = wasm_v128_load32_splat(w); w += 1;
122          $else:
123            v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
124          $for M in range(4, SUBMR, 4):
125            v128_t vacc${ABC[M:M+4]} = vacc0123;
126          if XNN_LIKELY(nnz != 0) {
127            do {
128              const intptr_t diff = *dmap++;
129              $if SUBMR >= 4:
130                const v128_t vi0123 = wasm_v128_load(input);
131              $elif SUBMR == 2:
132                const v128_t vi01 = wasm_v128_load64_splat(input);
133              $elif SUBMR == 1:
134                const v128_t vi0 = wasm_v128_load32_splat(input);
135              $for M in range(4, SUBMR, 4):
136                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
137              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
138              const v128_t vw = wasm_v128_load32_splat(w); w += 1;
139              $if SUBMR == 1:
140                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
141              $else:
142                $for M in range(0, SUBMR, 4):
143                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
144            } while (--nnz != 0);
145          }
146          $if SUBMR == 1:
147            $if X86:
148              v128_t vout${ABC[0]} = wasm_f32x4_pmin(vmax, vacc${ABC[0]});
149              vout${ABC[0]} = wasm_f32x4_pmax(vmin, vout${ABC[0]});
150            $else:
151              v128_t vout${ABC[0]} = wasm_f32x4_min(vacc${ABC[0]}, vmax);
152              vout${ABC[0]} = wasm_f32x4_max(vout${ABC[0]}, vmin);
153          $else:
154            $if X86:
155              $for M in range(0, SUBMR, 4):
156                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_pmin(vmax, vacc${ABC[M:min(M+4,SUBMR)]});
157              $for M in range(0, SUBMR, 4):
158                vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_pmax(vmin, vout${ABC[M:min(M+4,SUBMR)]});
159            $else:
160              $for M in range(0, SUBMR, 4):
161                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_min(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
162              $for M in range(0, SUBMR, 4):
163                vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_max(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
164          $if SUBMR >= 4:
165            wasm_v128_store(output, vout0123);
166          $elif SUBMR == 2:
167            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
168          $elif SUBMR == 1:
169            *output = wasm_f32x4_extract_lane(vout0, 0);
170
171          $for M in range(4, SUBMR, 4):
172            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
173          output = (float*restrict) ((uintptr_t) output + output_stride);
174        } while (--n != 0);
175        output = (float*restrict) ((uintptr_t) output - output_decrement);
176        input += ${SUBMR};
177      }
178  }
179}
180