// Auto-generated file. Do not edit!
//   Template: src/f32-gemm/neon-shuffle.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>

void xnn_f32_gemm_minmax_ukernel_4x8s4__neon(
    size_t mr,
    size_t nc,
    size_t kc,
    const float* restrict a,
    size_t a_stride,
    const float* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

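  // Set up one input (A) and one output (C) pointer per row. When fewer than
  // 4 rows remain (mr < 4), alias the out-of-range pointers to the previous
  // row so the extra lanes recompute and rewrite valid memory instead of
  // touching memory past the end of the input.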
  const float* a0 = a;
  float* c0 = c;
  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    a3 = a2;
    c3 = c2;
  }

  do {
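    // The first 8 floats of the packed weights w hold the bias for this
    // 8-column block; use them to initialize all four row accumulators.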
    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
    float32x4_t vacc1x0123 = vacc0x0123;
    float32x4_t vacc1x4567 = vacc0x4567;
    float32x4_t vacc2x0123 = vacc0x0123;
    float32x4_t vacc2x4567 = vacc0x4567;
    float32x4_t vacc3x0123 = vacc0x0123;
    float32x4_t vacc3x4567 = vacc0x4567;

    size_t k = kc;
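    // Main loop: process 4 elements of k (one float32x4 per row) per
    // iteration. This is the "s4" shuffle variant: instead of broadcasting
    // individual A lanes, each round rotates the A vectors by one lane
    // (vextq_f32) and multiplies against the next block of packed weights.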
    while (k >= 4 * sizeof(float)) {
      float32x4_t va0 = vld1q_f32(a0); a0 += 4;
      float32x4_t va1 = vld1q_f32(a1); a1 += 4;
      float32x4_t va2 = vld1q_f32(a2); a2 += 4;
      float32x4_t va3 = vld1q_f32(a3); a3 += 4;

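      // Shuffle round 0: unrotated A against weight columns c0.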
      const float32x4_t vb0123c0 = vld1q_f32(w + 0);
      const float32x4_t vb4567c0 = vld1q_f32(w + 4);

      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123c0);
      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123c0);
      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123c0);
      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123c0);
      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567c0);
      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567c0);
      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567c0);
      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567c0);

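      // Rotate each A vector left by one lane so the next round multiplies
      // against weight columns c1; rounds 2 and 3 repeat the same pattern.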
      va0 = vextq_f32(va0, va0, 1);
      va1 = vextq_f32(va1, va1, 1);
      va2 = vextq_f32(va2, va2, 1);
      va3 = vextq_f32(va3, va3, 1);

      const float32x4_t vb0123c1 = vld1q_f32(w + 8);
      const float32x4_t vb4567c1 = vld1q_f32(w + 12);

      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123c1);
      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123c1);
      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123c1);
      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123c1);
      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567c1);
      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567c1);
      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567c1);
      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567c1);

      va0 = vextq_f32(va0, va0, 1);
      va1 = vextq_f32(va1, va1, 1);
      va2 = vextq_f32(va2, va2, 1);
      va3 = vextq_f32(va3, va3, 1);

      const float32x4_t vb0123c2 = vld1q_f32(w + 16);
      const float32x4_t vb4567c2 = vld1q_f32(w + 20);

      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123c2);
      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123c2);
      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123c2);
      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123c2);
      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567c2);
      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567c2);
      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567c2);
      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567c2);

      va0 = vextq_f32(va0, va0, 1);
      va1 = vextq_f32(va1, va1, 1);
      va2 = vextq_f32(va2, va2, 1);
      va3 = vextq_f32(va3, va3, 1);

      const float32x4_t vb0123c3 = vld1q_f32(w + 24);
      const float32x4_t vb4567c3 = vld1q_f32(w + 28);

      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123c3);
      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123c3);
      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123c3);
      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123c3);
      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567c3);
      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567c3);
      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567c3);
      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567c3);

      w += 32;
      k -= 4 * sizeof(float);
    }
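    // Remainder loop: when kc is not a multiple of 4 floats, broadcast one
    // A element per row and multiply against one block of 8 weights at a time.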
    if XNN_UNLIKELY(k != 0) {
      do {
        const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
        const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
        const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
        const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;

        const float32x4_t vb0123 = vld1q_f32(w); w += 4;
        const float32x4_t vb4567 = vld1q_f32(w); w += 4;

        vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
        vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
        vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
        vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
        vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
        vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
        vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
        vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);

        k -= sizeof(float);
      } while (k != 0);
    }
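    // Clamp the accumulators to the [min, max] output range from params.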
    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
    vacc3x4567 = vminq_f32(vacc3x4567, vmax);

    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);

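    // Full 8-column tile: store all four rows and advance the C pointers to
    // the next block of columns.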
    if XNN_LIKELY(nc >= 8) {
      vst1q_f32(c3, vacc3x0123);
      vst1q_f32(c3 + 4, vacc3x4567);
      c3 = (float*) ((uintptr_t) c3 + cn_stride);
      vst1q_f32(c2, vacc2x0123);
      vst1q_f32(c2 + 4, vacc2x4567);
      c2 = (float*) ((uintptr_t) c2 + cn_stride);
      vst1q_f32(c1, vacc1x0123);
      vst1q_f32(c1 + 4, vacc1x4567);
      c1 = (float*) ((uintptr_t) c1 + cn_stride);
      vst1q_f32(c0, vacc0x0123);
      vst1q_f32(c0 + 4, vacc0x4567);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);

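      // Rewind the A pointers so the next column block reuses the same rows.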
      a3 = (const float*) ((uintptr_t) a3 - kc);
      a2 = (const float*) ((uintptr_t) a2 - kc);
      a1 = (const float*) ((uintptr_t) a1 - kc);
      a0 = (const float*) ((uintptr_t) a0 - kc);

      nc -= 8;

    } else {
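      // Partial tile: store 4, 2, then 1 column(s) according to the set bits
      // of nc, shifting the remaining lanes down after each store.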
      if (nc & 4) {
        vst1q_f32(c3, vacc3x0123); c3 += 4;
        vst1q_f32(c2, vacc2x0123); c2 += 4;
        vst1q_f32(c1, vacc1x0123); c1 += 4;
        vst1q_f32(c0, vacc0x0123); c0 += 4;

        vacc3x0123 = vacc3x4567;
        vacc2x0123 = vacc2x4567;
        vacc1x0123 = vacc1x4567;
        vacc0x0123 = vacc0x4567;
      }
      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
      if (nc & 2) {
        vst1_f32(c3, vacc3x01); c3 += 2;
        vst1_f32(c2, vacc2x01); c2 += 2;
        vst1_f32(c1, vacc1x01); c1 += 2;
        vst1_f32(c0, vacc0x01); c0 += 2;

        vacc3x01 = vget_high_f32(vacc3x0123);
        vacc2x01 = vget_high_f32(vacc2x0123);
        vacc1x01 = vget_high_f32(vacc1x0123);
        vacc0x01 = vget_high_f32(vacc0x0123);
      }
      if (nc & 1) {
        vst1_lane_f32(c3, vacc3x01, 0);
        vst1_lane_f32(c2, vacc2x01, 0);
        vst1_lane_f32(c1, vacc1x01, 0);
        vst1_lane_f32(c0, vacc0x01, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}