// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_pipelined(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    // Software pipelining: pre-load the first weight, input offset, and input
    // vectors before entering the channel loop.
    float32x4_t vw = vld1q_dup_f32(w); w += 1;
    intptr_t diff = *dmap++;
    float32x4_t vi0123 = vld1q_f32(input);
    $for M in range(4, MR, 4):
      float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      // Seed all accumulators with the pre-loaded per-channel initial value.
      $for M in range(0, MR, 4):
        float32x4_t vacc${ABC[M:M+4]} = vw;
      vw = vld1q_dup_f32(w); w += 1;
      if XNN_LIKELY(nnz != 0) {
        do {
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
          // Advance input by the pre-loaded byte offset, then pre-load the
          // next weight, offset, and input vectors for the following iteration.
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          $for M in range(0, MR, 16):
            __builtin_prefetch(input + ${M+16});
          diff = *dmap++;
          vw = vld1q_dup_f32(w); w += 1;
          __builtin_prefetch(w + 32);
          vi0123 = vld1q_f32(input);
          $for M in range(4, MR, 4):
            vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store ${MR} outputs for this channel.
      $for M in range(0, MR, 4):
        float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
      vst1q_f32(output, vout0123);
      $for M in range(4, MR, 4):
        vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
  // Tail: process the remainder of mc in power-of-two sub-blocks.
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR <= 2:
            float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
          $else:
            float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          $for M in range(4, SUBMR, 4):
            float32x4_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR == 1:
                const float32x2_t vi0 = vld1_dup_f32(input);
              $elif SUBMR == 2:
                const float32x2_t vi01 = vld1_f32(input);
              $else:
                const float32x4_t vi0123 = vld1q_f32(input);
              $for M in range(4, SUBMR, 4):
                const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              $for M in range(0, MR, 16):
                __builtin_prefetch(input + ${M+16});
              $if SUBMR <= 2:
                const float32x2_t vb = vld1_dup_f32(w); w += 1;
              $else:
                const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              __builtin_prefetch(w + 32);
              $if SUBMR <= 2:
                vacc${ABC[0:SUBMR]} = ${VMULADD_F32}(vacc${ABC[0:SUBMR]}, vi${ABC[0:SUBMR]}, vb);
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vb);
            } while (--nnz != 0);
          }
          $if SUBMR <= 2:
            float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
            vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
            $if SUBMR == 1:
              vst1_lane_f32(output, vout${ABC[0]}, 0);
            $else:
              vst1_f32(output, vout${ABC[0:SUBMR]});
          $else:
            $for M in range(0, SUBMR, 4):
              float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
            vst1q_f32(output, vout0123);
            $for M in range(4, SUBMR, 4):
              vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}
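
// Packed-data layout implied by the loads above (descriptive notes, not
// normative documentation):
//  - mc and output_stride are measured in bytes; mc must be a multiple of
//    sizeof(float).
//  - For each of the nc output channels, weights holds one initial
//    (accumulator-seeding) value followed by as many weight values as that
//    channel has nonzeroes, per the matching nidx_nnzmap entry.
//  - widx_dmap holds one signed byte offset per nonzero; it is applied to the
//    input pointer after each multiply-accumulate, chaining together the
//    input rows of that channel's nonzero weights.
//
// A hypothetical invocation of one generated instance (e.g. the MR=32, NR=1
// FMA variant; all identifiers other than the ukernel name and params fields
// are placeholders, and the packed buffers are assumed to be pre-built in the
// layout sketched above):
//
//   union xnn_f32_minmax_params params;
//   params.scalar.min = 0.0f;  // e.g. ReLU lower bound
//   params.scalar.max = 6.0f;  // e.g. ReLU6 upper bound
//   xnn_f32_spmm_minmax_ukernel_32x1__neonfma_pipelined(
//       batch_size * sizeof(float),              // mc, in bytes
//       output_channels,                         // nc
//       input, packed_weights, input_diffs, nonzeros_per_channel,
//       output,
//       batch_size * sizeof(float),              // output_stride, in bytes
//       &params);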