// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
  // Main loop: process the m dimension in full tiles of ${MR} elements.
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
      // Number of non-zero weights for this output channel.
      uint32_t nnz = *nnzmap++;
      $if UNROLL > 1:
        // The first packed weight per output channel is the bias; seed one
        // set of accumulators with it and zero the remaining sets.
        float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
        $for K in range(1, UNROLL):
          float32x4_t vacc0123x${K} = vmovq_n_f32(0.0f);
        $for M in range(4, MR, 4):
          float32x4_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
          $for K in range(1, UNROLL):
            float32x4_t vacc${ABC[M:M+4]}x${K} = vmovq_n_f32(0.0f);
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(UNROLL):
            const intptr_t diff${K} = dmap[${K}];
          dmap += ${UNROLL};
          $for K in range(UNROLL):
            const float32x4_t vi0123x${K} = vld1q_f32(input);
            $for M in range(4, MR, 4):
              const float32x4_t vi${ABC[M:M+4]}x${K} = vld1q_f32(input + ${M});
            // Advance the input pointer by the delta-coded byte offset.
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
            $for M in range(0, MR, 16):
              __builtin_prefetch(input + ${M+16});
            const float32x4_t vw${K} = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]}x${K} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}x${K}, vi${ABC[M:M+4]}x${K}, vw${K});
        }
        // Reduce the unrolled accumulator sets into one.
        $for M in range(0, MR, 4):
          float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
        $for K in range(1, UNROLL):
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = vaddq_f32(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
      $else:
        // The first packed weight per output channel is the bias.
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        $for M in range(4, MR, 4):
          float32x4_t vacc${ABC[M:M+4]} = vacc0123;
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          $for M in range(4, MR, 4):
            const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          $for M in range(0, MR, 16):
            __builtin_prefetch(input + ${M+16});
          const float32x4_t vw = vld1q_dup_f32(w); w += 1;
          __builtin_prefetch(w + 32);
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store one ${MR}-wide output tile.
      $for M in range(0, MR, 4):
        float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
      vst1q_f32(output, vout0123);
      $for M in range(4, MR, 4):
        vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
  // Remainder: handle leftover m elements with power-of-2 sub-tiles.
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR <= 2:
            float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
          $else:
            float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          $for M in range(4, SUBMR, 4):
            float32x4_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR == 1:
                const float32x2_t vi0 = vld1_dup_f32(input);
              $elif SUBMR == 2:
                const float32x2_t vi01 = vld1_f32(input);
              $else:
                const float32x4_t vi0123 = vld1q_f32(input);
              $for M in range(4, SUBMR, 4):
                const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              $if SUBMR <= 2:
                const float32x2_t vw = vld1_dup_f32(w); w += 1;
              $else:
                const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              $if SUBMR <= 2:
                vacc${ABC[0:SUBMR]} = ${VMULADD_F32}(vacc${ABC[0:SUBMR]}, vi${ABC[0:SUBMR]}, vw);
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
            } while (--nnz != 0);
          }
          $if SUBMR <= 2:
            float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
            vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
            $if SUBMR == 1:
              vst1_lane_f32(output, vout${ABC[0]}, 0);
            $else:
              vst1_f32(output, vout${ABC[0:SUBMR]});
          $else:
            $for M in range(0, SUBMR, 4):
              float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
            vst1q_f32(output, vout0123);
            $for M in range(4, SUBMR, 4):
              vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}
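
// ---------------------------------------------------------------------------
// Minimal usage sketch (illustrative only; excluded from the build with
// #if 0). It assumes the MR=16, NR=1 neonfma instantiation generated from
// this template, xnn_f32_spmm_minmax_ukernel_16x1__neonfma, and the data and
// sizes below are made-up examples, not XNNPACK API documentation. It shows
// the calling convention the kernel body above implies: mc and output_stride
// are in bytes, each output channel's packed weights start with a bias,
// nidx_nnzmap holds the per-channel non-zero count, and widx_dmap holds byte
// deltas applied to the input pointer after each non-zero weight is consumed.
#if 0
#include <stdint.h>

int main(void) {
  // 2 input channels x 16 m elements, stored channel-major.
  float input[2 * 16];
  for (int i = 0; i < 2 * 16; i++) {
    input[i] = (float) i;
  }

  // One output channel, two non-zero weights, packed as { bias, w0, w1 }.
  const float weights[3] = { 0.5f, 1.0f, -2.0f };
  const uint32_t nidx_nnzmap[1] = { 2 };
  // After consuming w0, jump +64 bytes (16 floats) to input channel 1;
  // after consuming w1, jump -64 bytes back to the start of the tile.
  const int32_t widx_dmap[2] = { 64, -64 };

  // The kernel reads params->scalar.min and params->scalar.max directly.
  union xnn_f32_minmax_params params;
  params.scalar.min = 0.0f;  // e.g. a ReLU6-style clamp
  params.scalar.max = 6.0f;

  float output[16];
  xnn_f32_spmm_minmax_ukernel_16x1__neonfma(
      16 * sizeof(float),  // mc, in bytes
      /*nc=*/1,
      input, weights, widx_dmap, nidx_nnzmap,
      output,
      16 * sizeof(float),  // output_stride, in bytes
      &params);
  // Expected: output[m] == clamp(0.5f + input[m] - 2.0f * input[16 + m],
  //                              0.0f, 6.0f) for m in [0, 16).
  return 0;
}
#endif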