• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert MR % 4 == 0
7$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
8#include <assert.h>
9
10#include <wasm_simd128.h>
11
12#include <xnnpack/spmm.h>
13
14$ARCH_SUFFIX = "_x86" if X86 else "_arm"
15
16void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
17    size_t mc,
18    size_t nc,
19    const float*restrict input,
20    const float*restrict weights,
21    const int32_t*restrict widx_dmap,
22    const uint32_t*restrict nidx_nnzmap,
23    float*restrict output,
24    size_t output_stride,
25    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
26{
27  assert(mc != 0);
28  assert(mc % sizeof(float) == 0);
29  assert(nc != 0);
30
31  const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
32  const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
33  $if UNROLL > 1:
34    const v128_t vzero = wasm_f64x2_splat(0.0);
35  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
36  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
37    const float*restrict w = weights;
38    const int32_t* dmap = widx_dmap;
39    const uint32_t* nnzmap = nidx_nnzmap;
40    size_t n = nc;
41    do {
42      uint32_t nnz = *nnzmap++;
43      $if UNROLL > 1:
44        v128_t vacc0123x0 = wasm_v32x4_load_splat(w);
45        w += 1;
46        $for K in range(1, UNROLL):
47          v128_t vacc0123x${K} = vzero;
48        $for M in range(4, MR, 4):
49          v128_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
50          $for K in range(1, UNROLL):
51            v128_t vacc${ABC[M:M+4]}x${K} = vzero;
52        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
53          $for K in range(UNROLL):
54            const intptr_t diff${K} = dmap[${K}];
55          dmap += ${UNROLL};
56          $for K in range(UNROLL):
57            const v128_t vi0123x${K} = wasm_v128_load(input);
58            $for M in range(4, MR, 4):
59              const v128_t vi${ABC[M:M+4]}x${K} = wasm_v128_load(input + ${M});
60            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
61            const v128_t vw${K} = wasm_v32x4_load_splat(w);
62            w += 1;
63            $for M in range(0, MR, 4):
64              vacc${ABC[M:M+4]}x${K} = wasm_f32x4_add(vacc${ABC[M:M+4]}x${K}, wasm_f32x4_mul(vi${ABC[M:M+4]}x${K}, vw${K}));
65        }
66        $for M in range(0, MR, 4):
67          v128_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
68        $for K in range(1, UNROLL):
69          $for M in range(0, MR, 4):
70            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
71      $else:
72        v128_t vacc0123 = wasm_v32x4_load_splat(w); w += 1;
73        $for M in range(4, MR, 4):
74          v128_t vacc${ABC[M:M+4]} = vacc0123;
75      if XNN_LIKELY(nnz != 0) {
76        do {
77          const intptr_t diff = *dmap++;
78          const v128_t vi0123 = wasm_v128_load(input);
79          $for M in range(4, MR, 4):
80            const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
81          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
82          const v128_t vw = wasm_v32x4_load_splat(w); w += 1;
83          $for M in range(0, MR, 4):
84            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
85        } while (--nnz != 0);
86      }
87      $if X86:
88        $for M in range(0, MR, 4):
89          v128_t vout${ABC[M:M+4]} = wasm_v128_bitselect(vacc${ABC[M:M+4]}, vmax, wasm_f32x4_le(vacc${ABC[M:M+4]}, vmax));
90        $for M in range(0, MR, 4):
91          vout${ABC[M:M+4]} = wasm_v128_bitselect(vmin, vout${ABC[M:M+4]}, wasm_f32x4_lt(vout${ABC[M:M+4]}, vmin));
92      $else:
93        $for M in range(0, MR, 4):
94          v128_t vout${ABC[M:M+4]} = wasm_f32x4_min(vacc${ABC[M:M+4]}, vmax);
95        $for M in range(0, MR, 4):
96          vout${ABC[M:M+4]} = wasm_f32x4_max(vout${ABC[M:M+4]}, vmin);
97      wasm_v128_store(output, vout0123);
98      $for M in range(4, MR, 4):
99        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
100      output = (float*restrict) ((uintptr_t) output + output_stride);
101    } while (--n != 0);
102    output = (float*restrict) ((uintptr_t) output - output_decrement);
103    input += ${MR};
104    mc -= ${MR} * sizeof(float);
105  }
106  if XNN_UNLIKELY(mc != 0) {
107    $for LOG2M in reversed(range((MR - 1).bit_length())):
108      $SUBMR = 1 << LOG2M
109      $if SUBMR * 2 >= MR:
110        output_decrement += ${MR - SUBMR} * sizeof(float);
111      $else:
112        output_decrement += ${SUBMR} * sizeof(float);
113      if (mc & (${SUBMR} * sizeof(float))) {
114        const float*restrict w = weights;
115        const int32_t* dmap = widx_dmap;
116        const uint32_t* nnzmap = nidx_nnzmap;
117        size_t n = nc;
118        do {
119          uint32_t nnz = *nnzmap++;
120          $if SUBMR == 1:
121            v128_t vacc0 = wasm_v32x4_load_splat(w); w += 1;
122          $elif SUBMR == 2:
123            v128_t vacc01 = wasm_v32x4_load_splat(w); w += 1;
124          $else:
125            v128_t vacc0123 = wasm_v32x4_load_splat(w); w += 1;
126          $for M in range(4, SUBMR, 4):
127            v128_t vacc${ABC[M:M+4]} = vacc0123;
128          if XNN_LIKELY(nnz != 0) {
129            do {
130              const intptr_t diff = *dmap++;
131              $if SUBMR >= 4:
132                const v128_t vi0123 = wasm_v128_load(input);
133              $elif SUBMR == 2:
134                const v128_t vi01 = wasm_v64x2_load_splat(input);
135              $elif SUBMR == 1:
136                const v128_t vi0 = wasm_v32x4_load_splat(input);
137              $for M in range(4, SUBMR, 4):
138                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
139              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
140              const v128_t vw = wasm_v32x4_load_splat(w); w += 1;
141              $if SUBMR == 1:
142                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
143              $else:
144                $for M in range(0, SUBMR, 4):
145                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
146            } while (--nnz != 0);
147          }
148          $if SUBMR == 1:
149            $if X86:
150              v128_t vout${ABC[0]} = wasm_v128_bitselect(vacc${ABC[0]}, vmax, wasm_f32x4_le(vacc${ABC[0]}, vmax));
151              vout${ABC[0]} = wasm_v128_bitselect(vmin, vout${ABC[0]}, wasm_f32x4_lt(vout${ABC[0]}, vmin));
152            $else:
153              v128_t vout${ABC[0]} = wasm_f32x4_min(vacc${ABC[0]}, vmax);
154              vout${ABC[0]} = wasm_f32x4_max(vout${ABC[0]}, vmin);
155          $else:
156            $if X86:
157              $for M in range(0, SUBMR, 4):
158                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vacc${ABC[M:min(M+4,SUBMR)]}, vmax, wasm_f32x4_le(vacc${ABC[M:min(M+4,SUBMR)]}, vmax));
159              $for M in range(0, SUBMR, 4):
160                vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vmin, vout${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_lt(vout${ABC[M:min(M+4,SUBMR)]}, vmin));
161            $else:
162              $for M in range(0, SUBMR, 4):
163                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_min(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
164              $for M in range(0, SUBMR, 4):
165                vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_max(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
166          $if SUBMR >= 4:
167            wasm_v128_store(output, vout0123);
168          $elif SUBMR == 2:
169            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
170          $elif SUBMR == 1:
171            *output = wasm_f32x4_extract_lane(vout0, 0);
172
173          $for M in range(4, SUBMR, 4):
174            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
175          output = (float*restrict) ((uintptr_t) output + output_stride);
176        } while (--n != 0);
177        output = (float*restrict) ((uintptr_t) output - output_decrement);
178        input += ${SUBMR};
179      }
180  }
181}
182