// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/spmm.h>

$ARCH_SUFFIX = "_x86" if X86 else "_arm"

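// Sparse-matrix times dense-matrix (SpMM) F32 micro-kernel with min/max
// output clamping, written with WebAssembly SIMD intrinsics. Lines beginning
// with $ are directives for XNNPACK's template generator, and ${...}
// expressions are substituted when the kernel is instantiated. "Pipelined"
// refers to the inner loop pre-loading the next weight and input vectors
// before the current multiply-accumulate, so loads overlap with arithmetic.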
void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}_pipelined${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

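  // Output clamping bounds, broadcast across all four lanes.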
  const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
  const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
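  // Main loop: process ${MR} elements of the M dimension per iteration
  // (mc counts bytes, not floats).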
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
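    // Software pipelining: pre-load the first weight, index delta, and input
    // vectors before entering the per-channel loop.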
    v128_t vw = wasm_v32x4_load_splat(w); w += 1;
    intptr_t diff = *dmap++;
    $for M in range(0, MR, 4):
      v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      $for M in range(0, MR, 4):
        v128_t vacc${ABC[M:M+4]} = vw;
      vw = wasm_v32x4_load_splat(w); w += 1;

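      // Multiply-accumulate this channel's non-zero weights, reloading the
      // next weight and input vectors ahead of each step.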
      $if UNROLL > 1:
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(0, UNROLL):
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            diff = *dmap++;
            vw = wasm_v32x4_load_splat(w); w += 1;
            $for M in range(0, MR, 4):
              vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
        }

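      // Process any remaining non-zero weights one at a time.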
      if XNN_LIKELY(nnz != 0) {
        do {
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);

          diff = *dmap++;
          vw = wasm_v32x4_load_splat(w); w += 1;
          $for M in range(0, MR, 4):
            vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
        } while (--nnz != 0);
      }
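      // Clamp the accumulators to [min, max]; the x86 variant uses compare +
      // bitselect, which typically lowers to cheaper code than f32x4.min/max
      // on x86 engines.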
      $if X86:
        $for M in range(0, MR, 4):
          v128_t vout${ABC[M:M+4]} = wasm_v128_bitselect(vacc${ABC[M:M+4]}, vmax, wasm_f32x4_le(vacc${ABC[M:M+4]}, vmax));
        $for M in range(0, MR, 4):
          vout${ABC[M:M+4]} = wasm_v128_bitselect(vmin, vout${ABC[M:M+4]}, wasm_f32x4_lt(vout${ABC[M:M+4]}, vmin));
      $else:
        $for M in range(0, MR, 4):
          v128_t vout${ABC[M:M+4]} = wasm_f32x4_min(vacc${ABC[M:M+4]}, vmax);
        $for M in range(0, MR, 4):
          vout${ABC[M:M+4]} = wasm_f32x4_max(vout${ABC[M:M+4]}, vmin);
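      // Store the ${MR} clamped outputs for this channel and advance to the
      // next output channel.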
      wasm_v128_store(output, vout0123);
      $for M in range(4, MR, 4):
        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
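  // Handle a tail of fewer than ${MR} M-dimension elements with progressively
  // smaller power-of-two tiles.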
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR == 1:
            v128_t vacc0 = wasm_v32x4_load_splat(w); w += 1;
          $elif SUBMR == 2:
            v128_t vacc01 = wasm_v32x4_load_splat(w); w += 1;
          $else:
            v128_t vacc0123 = wasm_v32x4_load_splat(w); w += 1;
          $for M in range(4, SUBMR, 4):
            v128_t vacc${ABC[M:M+4]} = vacc0123;
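          // Accumulate this channel's non-zero weights into the
          // ${SUBMR}-element tile.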
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR >= 4:
                const v128_t vi0123 = wasm_v128_load(input);
              $elif SUBMR == 2:
                const v128_t vi01 = wasm_v64x2_load_splat(input);
              $elif SUBMR == 1:
                const v128_t vi0 = wasm_v32x4_load_splat(input);
              $for M in range(4, SUBMR, 4):
                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const v128_t vw = wasm_v32x4_load_splat(w); w += 1;
              $if SUBMR == 1:
                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
            } while (--nnz != 0);
          }
          $if SUBMR == 1:
            $if X86:
              v128_t vout${ABC[0]} = wasm_v128_bitselect(vacc${ABC[0]}, vmax, wasm_f32x4_le(vacc${ABC[0]}, vmax));
              vout${ABC[0]} = wasm_v128_bitselect(vmin, vout${ABC[0]}, wasm_f32x4_lt(vout${ABC[0]}, vmin));
            $else:
              v128_t vout${ABC[0]} = wasm_f32x4_min(vacc${ABC[0]}, vmax);
              vout${ABC[0]} = wasm_f32x4_max(vout${ABC[0]}, vmin);
          $else:
            $if X86:
              $for M in range(0, SUBMR, 4):
                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vacc${ABC[M:min(M+4,SUBMR)]}, vmax, wasm_f32x4_le(vacc${ABC[M:min(M+4,SUBMR)]}, vmax));
              $for M in range(0, SUBMR, 4):
                vout${ABC[M:min(M+4,SUBMR)]} = wasm_v128_bitselect(vmin, vout${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_lt(vout${ABC[M:min(M+4,SUBMR)]}, vmin));
            $else:
              $for M in range(0, SUBMR, 4):
                v128_t vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_min(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
              $for M in range(0, SUBMR, 4):
                vout${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_max(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
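          // Store the ${SUBMR} clamped outputs: full vectors, the low two
          // lanes, or a single lane.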
          $if SUBMR >= 4:
            wasm_v128_store(output, vout0123);
          $elif SUBMR == 2:
            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
          $elif SUBMR == 1:
            *output = wasm_f32x4_extract_lane(vout0, 0);

          $for M in range(4, SUBMR, 4):
            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}