// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

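// Sparse (weights) * dense (input) matrix-matrix multiplication with min/max clamping,
// processing ${MR} elements of the dense dimension per main-loop iteration. mc is given in
// bytes (a multiple of sizeof(float)) and nc is the number of output channels. For each
// output channel the packed weights hold an initial accumulator value (the bias) followed
// by that channel's non-zero weights; nidx_nnzmap holds the per-channel non-zero count and
// widx_dmap the byte offsets by which the input pointer advances between consecutive
// non-zeros. The "pipelined" variant loads the next weight and input block ahead of use.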
void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_pipelined(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
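  // Main loop: process ${MR} dense elements per iteration across all nc output channels.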
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
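    // Software pipelining: pre-load the first packed weight and the first block of ${MR}
    // input values before the loops; later loads are issued one iteration ahead of their use.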
    float32x4_t vw = vld1q_dup_f32(w); w += 1;
    intptr_t diff = *dmap++;
    float32x4_t vi0123 = vld1q_f32(input);
    $for M in range(4, MR, 4):
      float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      $for M in range(0, MR, 4):
        float32x4_t vacc${ABC[M:M+4]} = vw;
      vw = vld1q_dup_f32(w); w += 1;
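      // Accumulate the nnz non-zero weights of this output channel; each dmap entry is the
      // byte offset by which the input pointer advances to the next non-zero's input values.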
      if XNN_LIKELY(nnz != 0) {
        do {
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          $for M in range(0, MR, 16):
            __builtin_prefetch(input + ${M+16});
          diff = *dmap++;
          vw = vld1q_dup_f32(w); w += 1;
          __builtin_prefetch(w + 32);
          vi0123 = vld1q_f32(input);
          $for M in range(4, MR, 4):
            vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
        } while (--nnz != 0);
      }
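      // Clamp the accumulators to [min, max] and store ${MR} outputs for this channel.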
      $for M in range(0, MR, 4):
        float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
      vst1q_f32(output, vout0123);
      $for M in range(4, MR, 4):
        vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
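  // Remainder: handle the final mc (< ${MR} floats) in power-of-two sized sub-blocks.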
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR <= 2:
            float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
          $else:
            float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          $for M in range(4, SUBMR, 4):
            float32x4_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR == 1:
                const float32x2_t vi0 = vld1_dup_f32(input);
              $elif SUBMR == 2:
                const float32x2_t vi01 = vld1_f32(input);
              $else:
                const float32x4_t vi0123 = vld1q_f32(input);
              $for M in range(4, SUBMR, 4):
                const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              $for M in range(0, MR, 16):
                __builtin_prefetch(input + ${M+16});
              $if SUBMR <= 2:
                const float32x2_t vb = vld1_dup_f32(w); w += 1;
              $else:
                const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              __builtin_prefetch(w + 32);
              $if SUBMR <= 2:
                vacc${ABC[0:SUBMR]} = ${VMULADD_F32}(vacc${ABC[0:SUBMR]}, vi${ABC[0:SUBMR]}, vb);
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vb);
            } while (--nnz != 0);
          }
          $if SUBMR <= 2:
            float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
            vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
            $if SUBMR == 1:
              vst1_lane_f32(output, vout${ABC[0]}, 0);
            $else:
              vst1_f32(output, vout${ABC[0:SUBMR]});
          $else:
            $for M in range(0, SUBMR, 4):
              float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
            vst1q_f32(output, vout0123);
            $for M in range(4, SUBMR, 4):
              vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}