• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 
7 #include <cassert>
8 
9 #include <xnnpack.h>
10 #include <xnnpack/aarch32-assembler.h>
11 #include <xnnpack/allocator.h>
12 #include <xnnpack/gemm.h>
13 
14 namespace xnnpack {
15 namespace aarch32 {
16 namespace {
17 class Generator : public Assembler {
18   using Assembler::Assembler;
19  public:
20   void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params);
21 };
22 
23 
24 // void xnn_f32_gemm_minmax_ukernel_4x8__aarch32_neon_cortex_a55(
25 //     size_t mr,                            r0
26 //     size_t nc,                            r1
27 //     size_t kc,                            r2 -> r5
28 //     const uint8_t*restrict a,             r3
29 //     size_t a_stride,          sp + 96 -> (r7)
30 //     const void*restrict w,    sp + 100 -> r9
31 //     uint8_t*restrict c,       sp + 104 -> r11
32 //     size_t cm_stride,         sp + 108 -> (r6)
33 //     size_t cn_stride,         sp + 112 -> (r0)
34 //     minmax_params*params,     sp + 116 -> (r5)
35 
36 // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
37 
38 // Register usage
39 
40 // r14 (lr) unused
41 
42 // A0   r3  d0
43 // A1  r12  d1
44 // A2  r10  d2
45 // A3   r7  d3
46 
47 // B    r9  d8,  d9, d10, d11
48 // B       d12, d13, d14, d15
49 
50 // C0  r11 d16-d17  q8  d18-d19  q9
51 // C1   r4 d20-d21 q10  d22-d23 q11
52 // C2   r8 d24-d25 q12  d26-d27 q13
53 // C3   r6 d28-d29 q14  d30-d31 q15
54 
55 // Clamp (r5) d4 d5 d6 d7
56 
57 // Converted from: src/f32-gemm/4x8-minmax-aarch32-neon-cortex-a55.S
generate(size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)58 void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
59 {
60   assert(nc_mod_nr < 8);
61   assert(kc != 0);
62   assert(kc % sizeof(float) == 0);
63 
64   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9;
65 
66   // Push 96 bytes
67   vpush({d8-d15}); // 64
68   push({r4, r5, r6, r7, r8, r9, r10, r11}); // +32 = 96
69 
70   ldr(r7, mem[sp, 96]); // a_stride
71   ldr(r11, mem[sp, 104]); // c
72   ldr(r6, mem[sp, 108]); // cm_stride
73   ldr(r9, mem[sp, 100]); // w
74 
75   // Clamp A and C pointers
76   cmp(r0, 2); // if mr >= 2
77   add(r12, r3, r7); //   a1 = a0 + a_stride
78   add(r4, r11, r6); //   c1 = c0 + cm_stride
79   movlo(r12, r3); // a1
80   movlo(r4, r11); // c1
81   // if mr > 2
82   add(r10, r12, r7); //   a2 = a1 + a_stride
83   add(r8, r4, r6); //   c2 = c1 + cm_stride
84   movls(r10, r12); // a2
85   movls(r8, r4); // c2
86 
87   cmp(r0, 4); // if mr >=4
88   add(r7, r10, r7); //   a3 = a2 + a_stride
89   add(r6, r8, r6); //   c3 = c2 + cm_stride
90   movlo(r7, r10); // a3
91   movlo(r6, r8); // c3
92 
93   align(8);
94   bind(l0);
95   // Load initial bias from w into accumulators
96   vldm(mem[r9]++, {d16-d19}); // Bias
97 
98   subs(r5, r2, 16); // kc - 16
99   pld(mem[r3, 0]); // Prefetch A
100   pld(mem[r3, 64]);
101   vmov(q10, q8);
102   pld(mem[r12, 0]);
103   pld(mem[r12, 64]);
104   vmov(q11, q9);
105   pld(mem[r10, 0]);
106   pld(mem[r10, 64]);
107   vmov(q12, q8);
108   pld(mem[r7, 0]);
109   pld(mem[r7, 64]);
110   vmov(q13, q9);
111   pld(mem[r9, 0]); // Prefetch B
112   pld(mem[r9, 64]);
113   vmov(q14, q8);
114   pld(mem[r9, 128]);
115   pld(mem[r9, 192]);
116   vmov(q15, q9);
117   pld(mem[r9, 256]);
118   pld(mem[r9, 320]);
119   blo(l4); // less than 4 channels?
120 
121   // Prologue
122   vld1_32({d0}, mem[r3]++); // A0
123   vld1_32({d1}, mem[r12]++); // A1
124   vld1_32({d2}, mem[r10]++); // A2
125   vld1_32({d3}, mem[r7]++); // A3
126   subs(r5, r5, 16);
127   vldm(mem[r9], {d8-d11}); // B0
128   vldr(d15, mem[r9, 56]); // B1CK 0
129   vldr(d13, mem[r9, 40]); // B1
130   blo(l2); // less than 4 channels?  skip main loop
131 
132   // Main loop - 4 floats of A (16 bytes)
133   // 32 FMA + 8 LD64 A + 8 LDR B
134   align(8);
135   bind(l1);
136   // First group of 16 FMA, Second group loads
137   // BLOCK 0
138   vmla_f32(q8, q4, d0[0]);
139   vld1_32({d4}, mem[r3]++); // A0
140   vmla_f32(q10, q4, d1[0]);
141   vld1_32({d5}, mem[r12]++); // A1
142   vmla_f32(q12, q4, d2[0]);
143 
144   // BLOCK 1
145   vmla_f32(q14, q4, d3[0]);
146   vldr(d12, mem[r9, 32]); // B1
147   vmla_f32(q9, q5, d0[0]);
148   vldr(d9, mem[r9, 72]); // B0
149   vmla_f32(q11, q5, d1[0]);
150 
151   // BLOCK 2
152   vmla_f32(q13, q5, d2[0]);
153   vld1_32({d6}, mem[r10]++); // A2
154   vmla_f32(q15, q5, d3[0]);
155   vld1_32({d7}, mem[r7]++); // A3
156   vmla_f32(q8, q6, d0[1]);
157 
158   // BLOCK 3
159   vmla_f32(q10, q6, d1[1]);
160   vldr(d14, mem[r9, 48]); // B1
161   vmla_f32(q12, q6, d2[1]);
162   vldr(d11, mem[r9, 88]); // B0
163   vmla_f32(q14, q6, d3[1]);
164 
165   // BLOCK 4
166   vmla_f32(q9, q7, d0[1]);
167   vldr(d8, mem[r9, 64]); // B0
168   vmla_f32(q11, q7, d1[1]);
169   vldr(d13, mem[r9, 104]); // B1
170   vmla_f32(q13, q7, d2[1]);
171   vldr(d10, mem[r9, 80]); // B0
172 
173   // BLOCK 5
174   vmla_f32(q15, q7, d3[1]);
175   vldr(d15, mem[r9, 120]); // B1
176 
177   // Second group of 16 FMA, First group of loads
178   // BLOCK 0
179   vmla_f32(q8, q4, d4[0]);
180   vld1_32({d0}, mem[r3]++); // A0
181   vmla_f32(q10, q4, d5[0]);
182   vld1_32({d1}, mem[r12]++); // A1
183   vmla_f32(q12, q4, d6[0]);
184 
185   // BLOCK 1
186   vmla_f32(q14, q4, d7[0]);
187   vldr(d12, mem[r9, 96]); // B1
188   vmla_f32(q9, q5, d4[0]);
189   vldr(d9, mem[r9, 136]); // B0
190   vmla_f32(q11, q5, d5[0]);
191 
192   // BLOCK 2
193   vmla_f32(q13, q5, d6[0]);
194   vld1_32({d2}, mem[r10]++); // A2
195   vmla_f32(q15, q5, d7[0]);
196   vld1_32({d3}, mem[r7]++); // A3
197   vmla_f32(q8, q6, d4[1]);
198 
199   // BLOCK 3
200   vmla_f32(q10, q6, d5[1]);
201   vldr(d14, mem[r9, 112]); // B1
202   vmla_f32(q12, q6, d6[1]);
203   vldr(d11, mem[r9, 152]); // B0
204   vmla_f32(q14, q6, d7[1]);
205   subs(r5, r5, 16);
206 
207   // BLOCK 4
208   vmla_f32(q9, q7, d4[1]);
209   vldr(d8, mem[r9, 128]); // B0
210   vmla_f32(q11, q7, d5[1]);
211   vldr(d13, mem[r9, 168]); // B1
212   vmla_f32(q13, q7, d6[1]);
213   vldr(d10, mem[r9, 144]); // B0
214 
215   // BLOCK 5
216   vmla_f32(q15, q7, d7[1]);
217   vldr(d15, mem[r9, 184]); // B1
218   add(r9, r9, 128); // B++
219   bhs(l1);
220 
221 
222   // Epilogue - 4 floats of A (16 bytes)
223   bind(l2);
224   // First group of 16 FMA, Second group loads
225   // BLOCK 0
226   vmla_f32(q8, q4, d0[0]);
227   vld1_32({d4}, mem[r3]++); // A0
228   vmla_f32(q10, q4, d1[0]);
229   vld1_32({d5}, mem[r12]++); // A1
230   vmla_f32(q12, q4, d2[0]);
231 
232   // BLOCK 1
233   vmla_f32(q14, q4, d3[0]);
234   vldr(d12, mem[r9, 32]); // B1
235   vmla_f32(q9, q5, d0[0]);
236   vldr(d9, mem[r9, 72]); // B0
237   vmla_f32(q11, q5, d1[0]);
238 
239   // BLOCK 2
240   vmla_f32(q13, q5, d2[0]);
241   vld1_32({d6}, mem[r10]++); // A2
242   vmla_f32(q15, q5, d3[0]);
243   vld1_32({d7}, mem[r7]++); // A3
244   vmla_f32(q8, q6, d0[1]);
245 
246   // BLOCK 3
247   vmla_f32(q10, q6, d1[1]);
248   vldr(d14, mem[r9, 48]); // B1
249   vmla_f32(q12, q6, d2[1]);
250   vldr(d11, mem[r9, 88]); // B0
251   vmla_f32(q14, q6, d3[1]);
252 
253   // BLOCK 4
254   vmla_f32(q9, q7, d0[1]);
255   vldr(d8, mem[r9, 64]); // B0
256   vmla_f32(q11, q7, d1[1]);
257   vldr(d13, mem[r9, 104]); // B1
258   vmla_f32(q13, q7, d2[1]);
259   vldr(d10, mem[r9, 80]); // B0
260 
261   // BLOCK 5
262   vmla_f32(q15, q7, d3[1]);
263   vldr(d15, mem[r9, 120]); // B1
264 
265   // Second group of 16 FMA, First group of loads
266   // BLOCK 0
267   vmla_f32(q8, q4, d4[0]);
268   vldr(d12, mem[r9, 96]); // B1
269   vmla_f32(q10, q4, d5[0]);
270   vmla_f32(q12, q4, d6[0]);
271 
272   // BLOCK 1
273   vmla_f32(q14, q4, d7[0]);
274   vldr(d14, mem[r9, 112]); // B1
275   vmla_f32(q9, q5, d4[0]);
276   vmla_f32(q11, q5, d5[0]);
277 
278   // BLOCK 2
279   vmla_f32(q13, q5, d6[0]);
280   vmla_f32(q15, q5, d7[0]);
281   vmla_f32(q8, q6, d4[1]);
282   add(r9, r9, 128); // B++
283 
284   // BLOCK 3
285   vmla_f32(q10, q6, d5[1]);
286   vmla_f32(q12, q6, d6[1]);
287   vmla_f32(q14, q6, d7[1]);
288   tst(r5, 15);
289 
290   // BLOCK 4
291   vmla_f32(q9, q7, d4[1]);
292   vmla_f32(q11, q7, d5[1]);
293   vmla_f32(q13, q7, d6[1]);
294 
295   // BLOCK 5
296   vmla_f32(q15, q7, d7[1]);
297 
298   // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
299   bne(l4);
300 
301   align(8);
302   bind(l3);
303   // Load params pointer
304   ldr(r0, mem[sp, 112]); // cn_stride
305   ldr(r5, mem[sp, 116]); // params
306   subs(r1, r1, 8);
307 
308   // Load min/max values
309   vld1r_32({d4,d5}, mem[r5]++);
310   vld1r_32({d6,d7}, mem[r5]);
311 
312   // Clamp
313   vmax_f32(q8, q8, q2);
314   vmax_f32(q9, q9, q2);
315   vmax_f32(q10, q10, q2);
316   vmax_f32(q11, q11, q2);
317   vmax_f32(q12, q12, q2);
318   vmax_f32(q13, q13, q2);
319   vmax_f32(q14, q14, q2);
320   vmax_f32(q15, q15, q2);
321   vmin_f32(q8, q8, q3);
322   vmin_f32(q9, q9, q3);
323   vmin_f32(q10, q10, q3);
324   vmin_f32(q11, q11, q3);
325   vmin_f32(q12, q12, q3);
326   vmin_f32(q13, q13, q3);
327   vmin_f32(q14, q14, q3);
328   vmin_f32(q15, q15, q3);
329 
330   // Store full 4 x 8
331   blo(l6);
332   vst1_32({d16-d19}, mem[r11], r0);
333   sub(r7, r7, r2);
334   vst1_32({d20-d23}, mem[r4], r0);
335   sub(r10, r10, r2);
336   vst1_32({d24-d27}, mem[r8], r0);
337   sub(r12, r12, r2);
338   vst1_32({d28-d31}, mem[r6], r0);
339   sub(r3, r3, r2);
340   bhi(l0);
341 
342   pop({r4, r5, r6, r7, r8, r9, r10, r11});
343   vpop({d8-d15});
344   bx(lr);
345 
346   align(8);
347   bind(l4);
348   // Is there a remainder?- 2 floats of A (8 bytes)
349   tst(r5, 8);
350   beq(l5);
351 
352   // Remainder - 2 floats of A (8 bytes)
353   vld1_32({d0}, mem[r3]++); // A0
354   vldm(mem[r9]++, {d8-d11}); // B0
355   vld1_32({d1}, mem[r12]++); // A1
356   vld1_32({d2}, mem[r10]++); // A2
357   vld1_32({d3}, mem[r7]++); // A3
358 
359   vmla_f32(q8, q4, d0[0]);
360   vmla_f32(q9, q5, d0[0]);
361   vmla_f32(q10, q4, d1[0]);
362   vmla_f32(q11, q5, d1[0]);
363   vldm(mem[r9]++, {d12-d15}); // B1
364   vmla_f32(q12, q4, d2[0]);
365   vmla_f32(q13, q5, d2[0]);
366   vmla_f32(q14, q4, d3[0]);
367   vmla_f32(q15, q5, d3[0]);
368   vmla_f32(q8, q6, d0[1]);
369   vmla_f32(q9, q7, d0[1]);
370   vmla_f32(q10, q6, d1[1]);
371   vmla_f32(q11, q7, d1[1]);
372   vmla_f32(q12, q6, d2[1]);
373   vmla_f32(q13, q7, d2[1]);
374   vmla_f32(q14, q6, d3[1]);
375   vmla_f32(q15, q7, d3[1]);
376 
377   // Is there a remainder?- 1 float of A (4 bytes)
378   tst(r5, 4);
379   beq(l3);
380 
381   bind(l5);
382   // Remainder- 1 float of A (4 bytes)
383   vldm(mem[r3]++, {s0}); // A0
384   vldm(mem[r9]++, {d8-d11}); // B0
385   vldm(mem[r12]++, {s2}); // A1
386   vldm(mem[r10]++, {s4}); // A2
387   vldm(mem[r7]++, {s6}); // A3
388   vmla_f32(q8, q4, d0[0]);
389   vmla_f32(q9, q5, d0[0]);
390   vmla_f32(q10, q4, d1[0]);
391   vmla_f32(q11, q5, d1[0]);
392   vmla_f32(q12, q4, d2[0]);
393   vmla_f32(q13, q5, d2[0]);
394   vmla_f32(q14, q4, d3[0]);
395   vmla_f32(q15, q5, d3[0]);
396   b(l3);
397 
398   // Store odd width
399   bind(l6);
400   tst(r1, 4);
401   beq(l7);
402   vst1_32({d16-d17}, mem[r11]++);
403   vst1_32({d20-d21}, mem[r4]++);
404   vmov(q8, q9);
405   vmov(q10, q11);
406   vst1_32({d24-d25}, mem[r8]++);
407   vst1_32({d28-d29}, mem[r6]++);
408   vmov(q12, q13);
409   vmov(q14, q15);
410 
411   bind(l7);
412   tst(r1, 2);
413   beq(l8);
414   vst1_32({d16}, mem[r11]++);
415   vst1_32({d20}, mem[r4]++);
416   vmov(d16, d17);
417   vmov(d20, d21);
418   vst1_32({d24}, mem[r8]++);
419   vst1_32({d28}, mem[r6]++);
420   vmov(d24, d25);
421   vmov(d28, d29);
422 
423   bind(l8);
424   tst(r1, 1);
425   beq(l9);
426   vst1_32({d16[0]}, mem[r11]);
427   vst1_32({d20[0]}, mem[r4]);
428   vst1_32({d24[0]}, mem[r8]);
429   vst1_32({d28[0]}, mem[r6]);
430 
431   bind(l9);
432   pop({r4, r5, r6, r7, r8, r9, r10, r11});
433   vpop({d8-d15});
434   bx(lr);
435 }
436 }  // namespace
437 }  // aarch32
438 }  // xnnpack
439 
xnn_generate_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a55(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)440 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a55(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
441   using namespace xnnpack::aarch32;
442   Generator g(code);
443   assert(params != nullptr);
444   g.generate(max_mr, nc_mod_nr, kc, nullptr);
445   g.finalize();
446   if (g.error() != xnnpack::Error::kNoError) {
447     return xnn_status_invalid_state;
448   }
449   return xnn_status_success;
450 }
451