• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
9 
10 #include <assert.h>
11 
12 #include <arm_neon.h>
13 
14 #include <xnnpack/spmm.h>
15 
16 
xnn_f32_spmm_ukernel_8x1__neonfma(uint32_t m,uint32_t n,const float * restrict a,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict c,const union xnn_f32_output_params params[restrict static1])17 void xnn_f32_spmm_ukernel_8x1__neonfma(
18     uint32_t m,
19     uint32_t n,
20     const float*restrict a,
21     const float*restrict weights,
22     const int32_t*restrict widx_dmap,
23     const uint32_t*restrict nidx_nnzmap,
24     float*restrict c,
25     const union xnn_f32_output_params params[restrict static 1])
26 {
27   assert(m != 0);
28 
29   const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
30   const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
31   size_t i = m;
32   while XNN_LIKELY(i >= 8) {
33     const float*restrict w = weights;
34     const int32_t* dmap = widx_dmap;
35     const uint32_t* nnzmap = nidx_nnzmap;
36     size_t j = n;
37     do {
38       uint32_t nnz = *nnzmap++;
39       float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
40       float32x4_t vacc4567 = vacc0123;
41       if XNN_LIKELY(nnz != 0) {
42         do {
43           const intptr_t diff = *dmap++;
44           const float32x4_t va0123 = vld1q_f32(a);
45           const float32x4_t va4567 = vld1q_f32(a + 4);
46           a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
47           const float32x4_t vb = vld1q_dup_f32(w); w += 1;
48           vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
49           vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
50         } while (--nnz != 0);
51       }
52       float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
53       float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
54       vout0123 = vmaxq_f32(vout0123, vmin);
55       vout4567 = vmaxq_f32(vout4567, vmin);
56       vst1q_f32(c, vout0123);
57       vst1q_f32(c + 4, vout4567);
58       c += m;
59     } while (--j != 0);
60     c -= m * n;
61     c += 8;
62     a += 8;
63     i -= 8;
64   }
65   if XNN_UNLIKELY(i != 0) {
66     if (i & 4) {
67       const float*restrict w = weights;
68       const int32_t* dmap = widx_dmap;
69       const uint32_t* nnzmap = nidx_nnzmap;
70       size_t j = n;
71       do {
72         uint32_t nnz = *nnzmap++;
73         float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
74         if XNN_LIKELY(nnz != 0) {
75           do {
76             const intptr_t diff = *dmap++;
77             const float32x4_t va0123 = vld1q_f32(a);
78             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
79             const float32x4_t vb = vld1q_dup_f32(w); w += 1;
80             vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
81           } while (--nnz != 0);
82         }
83         float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
84         vout0123 = vmaxq_f32(vout0123, vmin);
85         vst1q_f32(c, vout0123);
86         c += m;
87       } while (--j != 0);
88       c -= m * n;
89       c += 4;
90       a += 4;
91     }
92     if (i & 2) {
93       const float*restrict w = weights;
94       const int32_t* dmap = widx_dmap;
95       const uint32_t* nnzmap = nidx_nnzmap;
96       size_t j = n;
97       do {
98         uint32_t nnz = *nnzmap++;
99         float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
100         if XNN_LIKELY(nnz != 0) {
101           do {
102             const intptr_t diff = *dmap++;
103             const float32x2_t va01 = vld1_f32(a);
104             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
105             const float32x2_t vb = vld1_dup_f32(w); w += 1;
106             vacc01 = vfma_f32(vacc01, va01, vb);
107           } while (--nnz != 0);
108         }
109         float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
110         vout01 = vmax_f32(vout01, vget_low_f32(vmin));
111         vst1_f32(c, vout01);
112         c += m;
113       } while (--j != 0);
114       c -= m * n;
115       c += 2;
116       a += 2;
117     }
118     if (i & 1) {
119       const float*restrict w = weights;
120       const int32_t* dmap = widx_dmap;
121       const uint32_t* nnzmap = nidx_nnzmap;
122       size_t j = n;
123       do {
124         uint32_t nnz = *nnzmap++;
125         float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
126         if XNN_LIKELY(nnz != 0) {
127           do {
128             const intptr_t diff = *dmap++;
129             const float32x2_t va0 = vld1_dup_f32(a);
130             a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
131             const float32x2_t vb = vld1_dup_f32(w); w += 1;
132             vacc0 = vfma_f32(vacc0, va0, vb);
133           } while (--nnz != 0);
134         }
135         float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
136         vout0 = vmax_f32(vout0, vget_low_f32(vmin));
137         vst1_lane_f32(c, vout0, 0);
138         c += m;
139       } while (--j != 0);
140       c -= m * n;
141       c += 1;
142       a += 1;
143     }
144   }
145 }
146