1// Copyright (c) Facebook, Inc. and its affiliates. 2// All rights reserved. 3// 4// Copyright 2019 Google LLC 5// 6// This source code is licensed under the BSD-style license found in the 7// LICENSE file in the root directory of this source tree. 8 9$assert NR % 8 == 0 10$ABC = "0123456789ABCDEFGHIJKLMN" 11#include <assert.h> 12 13#include <arm_neon.h> 14 15#include <xnnpack/common.h> 16#include <xnnpack/gemm.h> 17 18 19void xnn_f16_gemm_ukernel_${MR}x${NR}__neonfp16arith_ld64( 20 size_t mr, 21 size_t nc, 22 size_t kc, 23 const void* restrict a, 24 size_t a_stride, 25 const void* restrict w, 26 void* restrict c, 27 size_t cm_stride, 28 size_t cn_stride, 29 const struct xnn_f16_output_params params[restrict static 1]) 30{ 31 assert(mr != 0); 32 assert(mr <= ${MR}); 33 assert(nc != 0); 34 assert(kc != 0); 35 assert(kc % sizeof(__fp16) == 0); 36 37 const __fp16* a0 = a; 38 __fp16* c0 = c; 39 $for M in range(1, MR): 40 const __fp16* a${M} = (const __fp16*) ((uintptr_t) a${M-1} + a_stride); 41 __fp16* c${M} = (__fp16*) ((uintptr_t) c${M-1} + cm_stride); 42 $if M % 2 == 0: 43 if XNN_UNPREDICTABLE(mr <= ${M}) { 44 a${M} = a${M-1}; 45 c${M} = c${M-1}; 46 } 47 $elif M + 1 == MR: 48 if XNN_UNPREDICTABLE(mr != ${M+1}) { 49 a${M} = a${M-1}; 50 c${M} = c${M-1}; 51 } 52 $else: 53 if XNN_UNPREDICTABLE(mr < ${M+1}) { 54 a${M} = a${M-1}; 55 c${M} = c${M-1}; 56 } 57 58 do { 59 $for N in range(0, NR, 8): 60 float16x8_t vacc0x${ABC[N:N+8]} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); 61 $for M in range(1, MR): 62 $for N in range(0, NR, 8): 63 float16x8_t vacc${M}x${ABC[N:N+8]} = vacc0x${ABC[N:N+8]}; 64 65 size_t k = kc; 66 while (k >= 4 * sizeof(__fp16)) { 67 $for M in range(MR): 68 const float16x4_t va${M} = vld1_f16(a${M}); a${M} += 4; 69 70 $for L in range(4): 71 $for N in range(0, NR, 8): 72 const float16x8_t vb${ABC[N:N+8]}c${L} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); 73 74 #if XNN_ARCH_ARM64 75 $for M in range(MR): 76 vacc${M}x${ABC[N:N+8]} = vfmaq_lane_f16(vacc${M}x${ABC[N:N+8]}, vb${ABC[N:N+8]}c${L}, va${M}, ${L}); 77 #else 78 $for M in range(MR): 79 const float16x8_t va${M}c${L} = vdupq_lane_f16(va${M}, ${L}); 80 81 $for N in range(0, NR, 8): 82 $for M in range(MR): 83 vacc${M}x${ABC[N:N+8]} = vfmaq_f16(vacc${M}x${ABC[N:N+8]}, va${M}c${L}, vb${ABC[N:N+8]}c${L}); 84 #endif 85 86 k -= 4 * sizeof(__fp16); 87 } 88 if XNN_UNLIKELY(k != 0) { 89 do { 90 $for M in range(MR): 91 const float16x8_t va${M} = vld1q_dup_f16(a${M}); a${M} += 1; 92 93 $for N in range(0, NR, 8): 94 const float16x8_t vb${ABC[N:N+8]} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); 95 96 $for N in range(0, NR, 8): 97 $for M in range(MR): 98 vacc${M}x${ABC[N:N+8]} = vfmaq_f16(vacc${M}x${ABC[N:N+8]}, va${M}, vb${ABC[N:N+8]}); 99 100 k -= sizeof(__fp16); 101 } while (k != 0); 102 } 103 104 const float16x8_t vscale = vld1q_dup_f16((const __fp16*) ¶ms->scale); 105 $for N in range(0, NR, 8): 106 $for M in range(MR): 107 vacc${M}x${ABC[N:N+8]} = vmulq_f16(vacc${M}x${ABC[N:N+8]}, vscale); 108 109 const float16x8_t vmax = vld1q_dup_f16((const __fp16*) ¶ms->max); 110 $for N in range(0, NR, 8): 111 $for M in range(MR): 112 vacc${M}x${ABC[N:N+8]} = vminq_f16(vacc${M}x${ABC[N:N+8]}, vmax); 113 114 const float16x8_t vmin = vld1q_dup_f16((const __fp16*) ¶ms->min); 115 $for N in range(0, NR, 8): 116 $for M in range(MR): 117 vacc${M}x${ABC[N:N+8]} = vmaxq_f16(vacc${M}x${ABC[N:N+8]}, vmin); 118 119 if XNN_LIKELY(nc >= ${NR}) { 120 $for M in range(MR): 121 vst1q_f16(c${M}, vacc${M}x${ABC[0:8]}); 122 $for N in range(8, NR, 8): 123 vst1q_f16(c${M} + ${N}, vacc${M}x${ABC[N:N+8]}); 124 c${M} = (__fp16*) ((uintptr_t) c${M} + cn_stride); 125 126 $for M in range(MR): 127 a${M} = (const __fp16*) ((uintptr_t) a${M} - kc); 128 129 nc -= ${NR}; 130 } else { 131 $for LOG2N in reversed(range(NR.bit_length())): 132 $if NR != 1 << LOG2N: 133 if (nc & ${1 << LOG2N}) { 134 $if LOG2N >= 3: 135 $for N in range(0, 1 << LOG2N, 8): 136 $for M in range(MR): 137 vst1q_f16(c${M}, vacc${M}x${ABC[N:N+8]}); c${M} += 8; 138 139 $for M in range(MR): 140 $for N in range(0, 1 << (LOG2N - 1), 4): 141 vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]}; 142 $elif LOG2N == 2: 143 $for M in range(MR): 144 vst1_f16(c${M}, vacc${M}x${ABC[0:4]}); c${M} += 4; 145 146 $for M in range(MR): 147 vacc${M}x${ABC[0:4]} = vget_high_f16(vacc${M}x${ABC[0:8]}); 148 $elif LOG2N == 1: 149 $for M in range(MR): 150 vst1_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpret_u32_f16(vacc${M}x${ABC[0:4]}), 0); c${M} += 2; 151 152 $for M in range(MR): 153 vacc${M}x${ABC[0:4]} = vext_f16(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]}, 2); 154 $elif LOG2N == 0: 155 $for M in range(MR): 156 vst1_lane_f16(c${M}, vacc${M}x${ABC[0:4]}, 0); 157 } 158 $if LOG2N == 3: 159 $for M in range(MR): 160 float16x4_t vacc${M}x${ABC[0:4]} = vget_low_f16(vacc${M}x${ABC[0:8]}); 161 162 nc = 0; 163 } 164 } while (nc != 0); 165} 166