// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJK"
// Select FMA vs. plain NEON multiply-accumulate intrinsics to match the
// "neonfma"/"neon" suffix in the kernel name.
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


void xnn_f32_spmm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}${"_unroll" + str(UNROLL) if UNROLL > 1 else ""}(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
  // Main loop: process the M dimension in tiles of ${MR} elements.
  while XNN_LIKELY(i >= ${MR}) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    do {
      // Number of non-zero weights in this output channel; the first value in
      // w is the channel bias, which initializes the accumulators.
      uint32_t nnz = *nnzmap++;
      $if UNROLL > 1:
        float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
        $for K in range(1, UNROLL):
          float32x4_t vacc0123x${K} = vmovq_n_f32(0.0f);
        $for M in range(4, MR, 4):
          float32x4_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
          $for K in range(1, UNROLL):
            float32x4_t vacc${ABC[M:M+4]}x${K} = vmovq_n_f32(0.0f);
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(UNROLL):
            const intptr_t diff${K} = dmap[${K}];
          dmap += ${UNROLL};
          $for K in range(UNROLL):
            const float32x4_t va0123x${K} = vld1q_f32(a);
            $for M in range(4, MR, 4):
              const float32x4_t va${ABC[M:M+4]}x${K} = vld1q_f32(a + ${M});
            $if MR == 16:
              __builtin_prefetch(a + 16);
            // Advance a by the pre-scaled byte offset to the input row of the
            // next non-zero weight.
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff${K});
            const float32x4_t vb${K} = vld1q_dup_f32(w); w += 1;
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]}x${K} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}x${K}, va${ABC[M:M+4]}x${K}, vb${K});
        }
        $for M in range(0, MR, 4):
          float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
        $for K in range(1, UNROLL):
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = vaddq_f32(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
      $else:
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        $for M in range(4, MR, 4):
          float32x4_t vacc${ABC[M:M+4]} = vacc0123;
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          $for M in range(4, MR, 4):
            const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
          $if MR == 16:
            __builtin_prefetch(a + 16);
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
        } while (--nnz != 0);
      }
      // Clamp the accumulators to the output range [min, max].
      $for M in range(0, MR, 4):
        float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
      vst1q_f32(c, vout0123);
      $for M in range(4, MR, 4):
        vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
      // Each output channel occupies m consecutive floats in c.
      c += m;
    } while (--j != 0);
    c -= m * n;
    c += ${MR};
    a += ${MR};
    i -= ${MR};
  }
  // Remainder loop: handle the tail of the M dimension with progressively
  // smaller power-of-2 tiles.
  if XNN_UNLIKELY(i != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      if (i & ${SUBMR}) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t j = n;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR <= 2:
            float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
          $else:
            float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          $for M in range(4, SUBMR, 4):
            float32x4_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR == 1:
                const float32x2_t va0 = vld1_dup_f32(a);
              $elif SUBMR == 2:
                const float32x2_t va01 = vld1_f32(a);
              $else:
                const float32x4_t va0123 = vld1q_f32(a);
              $for M in range(4, SUBMR, 4):
                const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              $if SUBMR <= 2:
                const float32x2_t vb = vld1_dup_f32(w); w += 1;
              $else:
                const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              $if SUBMR <= 2:
                vacc${ABC[0:SUBMR]} = ${VMULADD_F32}(vacc${ABC[0:SUBMR]}, va${ABC[0:SUBMR]}, vb);
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
            } while (--nnz != 0);
          }
          $if SUBMR <= 2:
            float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
            vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
            $if SUBMR == 1:
              vst1_lane_f32(c, vout${ABC[0]}, 0);
            $else:
              vst1_f32(c, vout${ABC[0:SUBMR]});
          $else:
            $for M in range(0, SUBMR, 4):
              float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
            vst1q_f32(c, vout0123);
            $for M in range(4, SUBMR, 4):
              vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
          c += m;
        } while (--j != 0);
        c -= m * n;
        c += ${SUBMR};
        a += ${SUBMR};
      }
  }
}