// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR == 2

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/gemm.h>

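// ${MR}x2 single-precision GEMM microkernel template (NR is fixed at 2) using 2-wide
// NEON vectors. Each instantiation computes an up-to-${MR}-row by 2-column tile of the
// output C from the A matrix and the packed weights w, consuming 2 elements of K per
// main-loop iteration (a 64-bit load from each A row), then clamps the results to the
// [min, max] range in params. For example, MR=4 with FMA and INC disabled expands to
// xnn_f32_gemm_ukernel_4x2__neon_lane_ld64.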
void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_lane_ld64(
    size_t mr,
    size_t nc,
    size_t kc,
    const float* restrict a,
    size_t a_stride,
    const float* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    $if INC:
      const float* restrict acc,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);
  $if INC:
    assert(acc != NULL);

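  // Set up one A (input) and one C (output) pointer per row of the ${MR}-row tile.
  // Rows past mr are aliased to the previous row, so the kernel always computes a
  // full ${MR}-row tile without reading or writing out of bounds when mr < ${MR}.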
  const float* a0 = a;
  float* c0 = c;
  $for M in range(1, MR):
    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
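    // Initialize the ${MR}x2 accumulator tile. The GEMMINC variant resumes from the
    // partial sums in acc; the plain GEMM variant loads the 2 per-column values (the
    // packed bias) at the start of w and copies them to every row.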
    $if INC:
      $for M in range(0, MR):
        float32x2_t vacc${M}x01 = vld1_f32(acc); acc += 2;
    $else:
      float32x2_t vacc0x01 = vld1_f32(w); w += 2;
      $for M in range(1, MR):
        float32x2_t vacc${M}x01 = vacc0x01;

    size_t k = kc;
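    // Main loop: process 2 elements of K per iteration, loading 64 bits (2 floats)
    // from each A row and 2 floats of B per K step, and accumulating with
    // lane-broadcast multiply-adds.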
    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
      $for M in range(MR):
        const float32x2_t va${M} = vld1_f32(a${M}); a${M} += 2;

      $for L in range(2):
        const float32x2_t vb01c${L} = vld1_f32(w); w += 2;

        $if FMA:
          #if XNN_ARCH_ARM64
            $for M in range(MR):
              vacc${M}x01 = vfma_lane_f32(vacc${M}x01, vb01c${L}, va${M}, ${L});
          #else
            $for M in range(MR):
              const float32x2_t va${M}c${L} = vdup_lane_f32(va${M}, ${L});
            $for M in range(MR):
              vacc${M}x01 = vfma_f32(vacc${M}x01, va${M}c${L}, vb01c${L});
          #endif
        $else:
          $for M in range(MR):
            vacc${M}x01 = vmla_lane_f32(vacc${M}x01, vb01c${L}, va${M}, ${L});
    }
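    // Remainder: kc is a multiple of sizeof(float), so at most 1 element of K is
    // left after the 2-element main loop; handle it with a broadcast load per A row.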
    if XNN_UNLIKELY(k != 0) {
      $for M in range(MR):
        const float32x2_t va${M} = vld1_dup_f32(a${M}); a${M} += 1;

      const float32x2_t vb01 = vld1_f32(w); w += 2;

      $for M in range(MR):
        $if FMA:
          vacc${M}x01 = vfma_f32(vacc${M}x01, va${M}, vb01);
        $else:
          vacc${M}x01 = vmla_f32(vacc${M}x01, va${M}, vb01);
    }

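    // Clamp the accumulators to the [min, max] output range supplied in params.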
    const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);
    $for M in range(MR):
      vacc${M}x01 = vmin_f32(vacc${M}x01, vmax);

    const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);
    $for M in range(MR):
      vacc${M}x01 = vmax_f32(vacc${M}x01, vmin);

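    // Store the tile. In the common case all ${NR} output columns are valid: store both
    // lanes, step each C pointer to the next column block, and rewind each A pointer by
    // kc bytes for the next pass over the remaining columns. Otherwise a single column
    // remains: store only lane 0 and finish.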
    if XNN_LIKELY(nc >= ${NR}) {
      $for M in range(MR):
        vst1_f32(c${M}, vacc${M}x01);
        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const float*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      assert(nc == 1);
      $for M in range(MR):
        vst1_lane_f32(c${M}, vacc${M}x01, 0);

      nc = 0;
    }
  } while (nc != 0);
}