1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24 #ifdef __ARM_FEATURE_SVE
25 #ifdef ARM_COMPUTE_ENABLE_SME2
26
27 #include "arm_gemm.hpp"
28 #include "../../utils.hpp"
29 #include "../../bfloat.hpp"
30
31 #include <cassert>
32 #include <limits>
33
34 namespace arm_gemm {
35
// sme2_gemv_fp32bf16fp32_dot_16VL: SME2 matrix-vector (GEMV) kernel.
//
// For each output column n in [0, N) it computes
//   output_ptr[n] = clamp(bias[n] + sum_k A_ptr[k] * B[k][n]).
// The fp32 A values are converted to bf16 on the fly (bfcvt). B is already
// bf16 in a packed layout. Products are accumulated in fp32 in the ZA array
// using multi-vector BFDOT. The clamp is only applied for ReLU / BoundedReLU.
// Accumulators start at zero when `bias` is null.
//
// Parameters:
//   A_ptr      - K fp32 input values, read 8 at a time. Loads past K are
//                predicated off by whilelt on the remaining K count.
//   B_ptr      - bf16 weights in the packed layout consumed below. Every pair
//                of k values steps over 16 vectors of B (addvl #16), no
//                matter how many column blocks are active in the current
//                pass. In effect the packed width is rounded up to 4 column
//                blocks (see the prefetch size computation).
//                NOTE(review): for odd K the last k-pair of B is still read in
//                full (the A side is zero-filled by the predicated load), so
//                the packing presumably pads K to even -- confirm with the
//                pretransposing code.
//   output_ptr - N fp32 outputs. In each pass the last ZA group is stored
//                under the tail predicate p8 (whilelt on the remaining N).
//                Earlier groups are stored in full.
//   N, K       - output width and reduction depth.
//   bias       - optional fp32 bias. If nullptr, the whole ZA array is zeroed.
//                NOTE(review): bias vectors are loaded with the all-true pn9
//                predicate. The tail block therefore reads 4 full vectors of
//                bias even when fewer than that remain in N. This appears to
//                require a bias buffer padded to a multiple of 4*VL floats --
//                confirm against callers.
//   act        - only BoundedReLU (clamp to [0, param1]) and ReLU (clamp to
//                [0, +inf]) take effect. None and any other type go through
//                `default:` and produce no clamping.
//   bool       - unnamed and unused by this kernel.
sme2_gemv_fp32bf16fp32_dot_16VL(const float * A_ptr,const bfloat16 * B_ptr,float * output_ptr,size_t N,size_t K,const float * bias,Activation act,bool)36 void sme2_gemv_fp32bf16fp32_dot_16VL (
37 const float *A_ptr, const bfloat16 *B_ptr, float *output_ptr,
38 size_t N, size_t K,
39 const float *bias, Activation act, bool
40 )
41 {
// Parameter block the assembly reads through [args_ptr] + offsetof(...).
// Only minval/maxval are read by the asm (the offset_min / offset_max
// operands). B_ptr is assigned below, but the asm takes B_ptr as its own
// operand instead. output_offset and input_initial_col are not referenced
// by this kernel.
42 struct KernelArgs {
43 float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
44 float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
45 const bfloat16 *B_ptr = {};
46 size_t output_offset = {};
47 unsigned int input_initial_col = {};
48 } ka;
49
// Bit 1 of `flags` selects the clamping epilogue (the asm tests it with
// `tbz %x[flags], #1`).
50 unsigned long flags=0;
51 ka.B_ptr = B_ptr;
52 switch(act.type) {
// Unsupported activation types are treated like None (no clamp).
53 default:
54 case Activation::Type::None:
55 break;
56 case Activation::Type::BoundedReLU:
57 ka.maxval = static_cast<float>(act.param1);
58 /* fall through */
59 case Activation::Type::ReLU:
60 ka.minval = 0;
61 flags |= 0x2;
62 break;
63 }
// Register roles inside the asm, as set up below:
//   x28 = 4 * (32-bit lanes per vector): outputs per column block (cntw x4)
//   x27 = column blocks remaining = ceil(N / x28)
//   x9  = ZA vector-select base register for za.d[x9, #i] (always 0)
//   x26 = 4, start index of the second whilelt over K
//   x25 = B cursor, x24 = output cursor, x23 = bias cursor (0 if no bias)
//   x22 = A cursor, x20 = K remaining, x19 = N remaining for tail predicate p8
//   p2  = all-true, pn9 = all-true (predicate-as-counter)
64 __asm__ __volatile__(
65 "ptrue p2.b\n"
// NOTE(review): 0xd503477f looks like the encoding of plain SMSTART
// (streaming mode + ZA), not `SMSTART ZA` as the inline disassembly says.
// The SMSTOP at the exit is likewise the combined form -- confirm.
66 ".inst 0xd503477f // SMSTART ZA\n"
// x27 = ceil(N / x28) column blocks.
67 "cntw x28, ALL, MUL #4\n"
68 "add x27, %x[N], x28\n"
69 "sub x27, x27, #0x1\n"
70 "udiv x27, x27, x28\n"
// x21 = roundup(x27, 4) * x28 * K bf16 elements of packed B. It is shifted
// left by 1 further down to get bytes.
71 "add x21, x27, #0x3\n"
72 "and x21, x21, #0xfffffffffffffffc\n"
73 "mul x21, x21, x28\n"
74 "mul x21, x21, %x[K]\n"
75 "mov x9, #0x0\n"
76 "mov x26, #0x4\n"
77 "mov x25, %x[B_ptr]\n"
78 "mov x24, %x[output_ptr]\n"
// p2 is set again here. Presumably this is because entering streaming mode
// resets the SVE predicate state, so the first ptrue does not survive.
79 "ptrue p2.b\n"
80 ".inst 0x25207811 // ptrue pn9.b\n"
81 "lsl x21, x21, #0x1\n"
82 "mov x20, #0x1\n"
// Range-prefetch all of packed B once. While the byte length is >= 0x200000,
// halve it and double the repeat count in x20. If the length is odd at that
// point, skip the prefetch. The metadata register is then built as
// length | length << 38 | (count - 1) << 22. Presumably these are the RPRFM
// length / stride / count fields -- TODO confirm against the RPRFM spec.
83 "1:" // RHS size check loop
84 "cmp x21, #0x200000\n"
85 "blt 2f\n"
86 "tbnz x21, #0, 3f\n"
87 "lsr x21, x21, #0x1\n"
88 "lsl x20, x20, #0x1\n"
89 "b 1b\n"
90 "2:" // RHS do prefetch
91 "lsl x19, x21, #0x26\n"
92 "sub x20, x20, #0x1\n"
93 "lsl x20, x20, #0x16\n"
94 "orr x21, x21, x19\n"
95 "orr x21, x21, x20\n"
96 ".inst 0xf8b54b3a // rprfm pldonce, x21, [x25]\n"
97 "3:" // RHS prefetch exit
// The bias cursor is only advanced by the width-4 path. Widths 1-3 run once
// and then exit.
98 "mov x23, %x[bias]\n"
// Column loop: with >= 4 blocks left, take the width-4 path (it loops back to
// 4b). Otherwise dispatch once to width 3 / 2 / 1 and exit.
99 "4:" // Column loop
100 "cmp x27, #0x4\n"
101 "bge 28f\n"
102 "cmp x27, #0x2\n"
103 "bgt 20f\n"
104 "beq 12f\n"
// Width 1: a single ZA group (za.d[x9, #0]) covering up to x28 outputs.
// rprfm pldmany hints the K*4 bytes of A.
105 "mov x22, %x[A_ptr]\n"
106 "lsl x21, %x[K], #0x2\n"
107 "mov x19, %x[N]\n"
108 "mov x20, %x[K]\n"
109 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
// p8 = tail predicate over the N values remaining for the last ZA group.
110 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
// Initialise the accumulators from bias, or zero the whole ZA array.
111 "cbz x23, 5f\n"
112 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
113 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
114 "b 6f\n"
115 "5:" // Width 1: no bias
116 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
117 "6:" // Width 1: setup done
118 "cmp x20, #0x8\n"
119 "ble 8f\n"
// Main loop: consumes 8 k values per iteration while more than 8 remain.
// The two ld1rqw load A[k..k+3] and A[k+4..k+7], predicated on the remaining
// K. bfcvt converts them to bf16. uzp1/trn1 pack the 8 bf16 values into every
// 128-bit segment of z0, so bfdot index [i] selects the k-pair (2i, 2i+1).
// Each bfdot consumes 4 vectors of B per ZA group. B advances 16 vectors per
// k-pair.
120 "7:" // Width 1: Multiply loop: Main loop head
121 "whilelt p1.s, XZR, x20\n"
122 "whilelt p0.s, x26, x20\n"
123 "ld1rqw { z0.s }, p1/Z, [x22]\n"
124 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
125 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
126 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
127 "uzp1 z0.h, z0.h, z0.h\n"
128 "sub x20, x20, #0x8\n"
129 "uzp1 z11.h, z11.h, z11.h\n"
130 "trn1 z0.d, z0.d, z11.d\n"
131 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
132 "addvl x25, x25, #16\n"
133 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
134 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
135 "addvl x25, x25, #16\n"
136 "cmp x20, #0x8\n"
137 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
138 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
139 "addvl x25, x25, #16\n"
140 "add x22, x22, #0x20\n"
141 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
142 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
143 "addvl x25, x25, #16\n"
144 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
145 "bgt 7b\n"
// Tail: at most 8 k values remain. The A packing is the same as in the main
// loop. Up to four bfdots follow, one per k-pair. The remaining count is
// decremented by 2 after each, and the sequence stops once it reaches <= 0.
146 "8:" // Width 1: Multiply loop: Single iteration only
147 "whilelt p1.s, XZR, x20\n"
148 "whilelt p0.s, x26, x20\n"
149 "ld1rqw { z0.s }, p1/Z, [x22]\n"
150 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
151 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
152 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
153 "uzp1 z0.h, z0.h, z0.h\n"
154 "subs x20, x20, #0x2\n"
155 "uzp1 z11.h, z11.h, z11.h\n"
156 "trn1 z0.d, z0.d, z11.d\n"
157 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
158 "add x22, x22, #0x20\n"
159 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
160 "addvl x25, x25, #16\n"
161 "ble 9f\n"
162 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
163 "subs x20, x20, #0x2\n"
164 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
165 "addvl x25, x25, #16\n"
166 "ble 9f\n"
167 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
168 "subs x20, x20, #0x2\n"
169 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
170 "addvl x25, x25, #16\n"
171 "ble 9f\n"
172 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
173 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
174 "addvl x25, x25, #16\n"
// Epilogue: if flags bit 1 is set, clamp to [minval, maxval] loaded from ka.
// Otherwise copy ZA out unchanged. The store uses the tail predicate p8.
175 "9:" // Width 1: Multiply loop: multiply skip
176 "tbz %x[flags], #1, 10f\n"
177 "add x20, %x[args_ptr], %[offset_min]\n"
178 "add x19, %x[args_ptr], %[offset_max]\n"
179 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
180 "ld1rw { z29.s }, p2/Z, [x20]\n"
181 "ld1rw { z18.s }, p2/Z, [x19]\n"
182 ".inst 0xc1b2cba8 // fclamp { z8.s-z11.s }, z29.s, z18.s\n"
183 ".inst 0xa060c308 // st1w { z8.s-z11.s }, p8, [x24]\n"
184 "addvl x24, x24, #4\n"
185 "b 11f\n"
186 "10:" // Width 1: No activation
187 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
188 ".inst 0xa060c308 // st1w { z8.s-z11.s }, p8, [x24]\n"
189 "addvl x24, x24, #4\n"
190 "11:" // Width 1: Output done
191 "b 36f\n"
// Width 2: ZA groups #0 and #1. Group #0 is full. The tail predicate covers
// N - x28.
192 "12:" // Width 2
193 "mov x22, %x[A_ptr]\n"
194 "lsl x21, %x[K], #0x2\n"
195 "sub x19, %x[N], x28\n"
196 "mov x20, %x[K]\n"
197 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
198 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
199 "cbz x23, 13f\n"
200 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
201 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
202 ".inst 0xa041c6f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
203 ".inst 0xc0042e01 // mova za.d[x9, #1], { z16.d-z19.d }\n"
204 "b 14f\n"
205 "13:" // Width 2: no bias
206 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
207 "14:" // Width 2: setup done
208 "cmp x20, #0x8\n"
209 "ble 16f\n"
210 "15:" // Width 2: Multiply loop: Main loop head
211 "whilelt p1.s, XZR, x20\n"
212 "whilelt p0.s, x26, x20\n"
213 "ld1rqw { z0.s }, p1/Z, [x22]\n"
214 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
215 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
216 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
217 "uzp1 z0.h, z0.h, z0.h\n"
218 "sub x20, x20, #0x8\n"
219 "uzp1 z11.h, z11.h, z11.h\n"
220 "trn1 z0.d, z0.d, z11.d\n"
221 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
222 "cmp x20, #0x8\n"
223 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
224 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
225 "addvl x25, x25, #16\n"
226 "add x22, x22, #0x20\n"
227 ".inst 0xc150b119 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[0]\n"
228 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
229 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
230 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
231 "addvl x25, x25, #16\n"
232 ".inst 0xc150b499 // bfdot za.s[x9, 1], { z4.h-z7.h }, z0.h[1]\n"
233 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
234 ".inst 0xa041a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
235 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
236 "addvl x25, x25, #16\n"
237 ".inst 0xc150ba99 // bfdot za.s[x9, 1], { z20.h-z23.h }, z0.h[2]\n"
238 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
239 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
240 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
241 "addvl x25, x25, #16\n"
242 ".inst 0xc150bd19 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[3]\n"
243 "bgt 15b\n"
244 "16:" // Width 2: Multiply loop: Single iteration only
245 "whilelt p1.s, XZR, x20\n"
246 "whilelt p0.s, x26, x20\n"
247 "ld1rqw { z0.s }, p1/Z, [x22]\n"
248 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
249 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
250 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
251 "uzp1 z0.h, z0.h, z0.h\n"
252 "subs x20, x20, #0x2\n"
253 "uzp1 z11.h, z11.h, z11.h\n"
254 "trn1 z0.d, z0.d, z11.d\n"
255 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
256 "add x22, x22, #0x20\n"
257 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
258 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
259 "addvl x25, x25, #16\n"
260 ".inst 0xc150b119 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[0]\n"
261 "ble 17f\n"
262 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
263 "subs x20, x20, #0x2\n"
264 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
265 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
266 ".inst 0xc150b499 // bfdot za.s[x9, 1], { z4.h-z7.h }, z0.h[1]\n"
267 "addvl x25, x25, #16\n"
268 "ble 17f\n"
269 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
270 "subs x20, x20, #0x2\n"
271 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
272 ".inst 0xa041a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
273 ".inst 0xc150ba99 // bfdot za.s[x9, 1], { z20.h-z23.h }, z0.h[2]\n"
274 "addvl x25, x25, #16\n"
275 "ble 17f\n"
276 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
277 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
278 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
279 ".inst 0xc150bd19 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[3]\n"
280 "addvl x25, x25, #16\n"
// Epilogue: group #0 is stored with pn9 (full), group #1 with the tail
// predicate p8.
281 "17:" // Width 2: Multiply loop: multiply skip
282 "tbz %x[flags], #1, 18f\n"
283 "add x20, %x[args_ptr], %[offset_min]\n"
284 "add x19, %x[args_ptr], %[offset_max]\n"
285 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
286 "ld1rw { z29.s }, p2/Z, [x20]\n"
287 ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n"
288 "ld1rw { z18.s }, p2/Z, [x19]\n"
289 ".inst 0xc1b2cba8 // fclamp { z8.s-z11.s }, z29.s, z18.s\n"
290 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
291 ".inst 0xc1b2cbac // fclamp { z12.s-z15.s }, z29.s, z18.s\n"
292 ".inst 0xa061c30c // st1w { z12.s-z15.s }, p8, [x24, #0x4, MUL VL]\n"
293 "addvl x24, x24, #8\n"
294 "b 19f\n"
295 "18:" // Width 2: No activation
296 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
297 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
298 ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n"
299 ".inst 0xa061c30c // st1w { z12.s-z15.s }, p8, [x24, #0x4, MUL VL]\n"
300 "addvl x24, x24, #8\n"
301 "19:" // Width 2: Output done
302 "b 36f\n"
// Width 3: ZA groups #0..#2. The tail predicate covers N - 2 * x28.
303 "20:" // Width 3
304 "mov x19, #0x2\n"
305 "mov x22, %x[A_ptr]\n"
306 "lsl x21, %x[K], #0x2\n"
307 "msub x19, x28, x19, %x[N]\n"
308 "mov x20, %x[K]\n"
309 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
310 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
311 "cbz x23, 21f\n"
312 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
313 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
314 ".inst 0xa041c6f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
315 ".inst 0xc0042e01 // mova za.d[x9, #1], { z16.d-z19.d }\n"
316 ".inst 0xa042c6fc // ld1w { z28.s-z31.s }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
317 ".inst 0xc0042f82 // mova za.d[x9, #2], { z28.d-z31.d }\n"
318 "b 22f\n"
319 "21:" // Width 3: no bias
320 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
321 "22:" // Width 3: setup done
322 "cmp x20, #0x8\n"
323 "ble 24f\n"
324 "23:" // Width 3: Multiply loop: Main loop head
325 "whilelt p1.s, XZR, x20\n"
326 "whilelt p0.s, x26, x20\n"
327 "ld1rqw { z0.s }, p1/Z, [x22]\n"
328 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
329 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
330 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
331 "uzp1 z0.h, z0.h, z0.h\n"
332 "sub x20, x20, #0x8\n"
333 "uzp1 z11.h, z11.h, z11.h\n"
334 "trn1 z0.d, z0.d, z11.d\n"
335 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
336 "cmp x20, #0x8\n"
337 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
338 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
339 "add x22, x22, #0x20\n"
340 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
341 ".inst 0xc150b119 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[0]\n"
342 "addvl x25, x25, #16\n"
343 ".inst 0xc150b21a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[0]\n"
344 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
345 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
346 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
347 ".inst 0xa042a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
348 ".inst 0xc150b499 // bfdot za.s[x9, 1], { z4.h-z7.h }, z0.h[1]\n"
349 "addvl x25, x25, #16\n"
350 ".inst 0xc150b59a // bfdot za.s[x9, 2], { z12.h-z15.h }, z0.h[1]\n"
351 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
352 ".inst 0xa041a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
353 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
354 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
355 ".inst 0xc150ba99 // bfdot za.s[x9, 1], { z20.h-z23.h }, z0.h[2]\n"
356 "addvl x25, x25, #16\n"
357 ".inst 0xc150ba1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[2]\n"
358 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
359 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
360 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
361 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
362 ".inst 0xc150bd19 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[3]\n"
363 "addvl x25, x25, #16\n"
364 ".inst 0xc150be1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[3]\n"
365 "bgt 23b\n"
366 "24:" // Width 3: Multiply loop: Single iteration only
367 "whilelt p1.s, XZR, x20\n"
368 "whilelt p0.s, x26, x20\n"
369 "ld1rqw { z0.s }, p1/Z, [x22]\n"
370 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
371 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
372 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
373 "uzp1 z0.h, z0.h, z0.h\n"
374 "subs x20, x20, #0x2\n"
375 "uzp1 z11.h, z11.h, z11.h\n"
376 "trn1 z0.d, z0.d, z11.d\n"
377 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
378 "add x22, x22, #0x20\n"
379 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
380 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
381 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
382 ".inst 0xc150b119 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[0]\n"
383 "addvl x25, x25, #16\n"
384 ".inst 0xc150b21a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[0]\n"
385 "ble 25f\n"
386 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
387 "subs x20, x20, #0x2\n"
388 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
389 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
390 ".inst 0xc150b499 // bfdot za.s[x9, 1], { z4.h-z7.h }, z0.h[1]\n"
391 ".inst 0xa042a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
392 ".inst 0xc150b59a // bfdot za.s[x9, 2], { z12.h-z15.h }, z0.h[1]\n"
393 "addvl x25, x25, #16\n"
394 "ble 25f\n"
395 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
396 "subs x20, x20, #0x2\n"
397 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
398 ".inst 0xa041a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
399 ".inst 0xc150ba99 // bfdot za.s[x9, 1], { z20.h-z23.h }, z0.h[2]\n"
400 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
401 ".inst 0xc150ba1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[2]\n"
402 "addvl x25, x25, #16\n"
403 "ble 25f\n"
404 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
405 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
406 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
407 ".inst 0xc150bd19 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[3]\n"
408 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
409 ".inst 0xc150be1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[3]\n"
410 "addvl x25, x25, #16\n"
// Epilogue: groups #0-#1 are stored full, group #2 under the tail predicate p8.
411 "25:" // Width 3: Multiply loop: multiply skip
412 "tbz %x[flags], #1, 26f\n"
413 "add x20, %x[args_ptr], %[offset_min]\n"
414 "add x19, %x[args_ptr], %[offset_max]\n"
415 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
416 "ld1rw { z29.s }, p2/Z, [x20]\n"
417 ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n"
418 "ld1rw { z18.s }, p2/Z, [x19]\n"
419 ".inst 0xc1b2cba8 // fclamp { z8.s-z11.s }, z29.s, z18.s\n"
420 ".inst 0xc0062c44 // mova { z4.d-z7.d }, za.d[x9, #2]\n"
421 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
422 ".inst 0xc1b2cbac // fclamp { z12.s-z15.s }, z29.s, z18.s\n"
423 ".inst 0xa061c70c // st1w { z12.s-z15.s }, pn9.b, [x24, #0x4, MUL VL]\n"
424 ".inst 0xc1b2cba4 // fclamp { z4.s-z7.s }, z29.s, z18.s\n"
425 ".inst 0xa062c304 // st1w { z4.s-z7.s }, p8, [x24, #0x8, MUL VL]\n"
426 "addvl x24, x24, #12\n"
427 "b 27f\n"
428 "26:" // Width 3: No activation
429 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
430 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
431 ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n"
432 ".inst 0xa061c70c // st1w { z12.s-z15.s }, pn9.b, [x24, #0x4, MUL VL]\n"
433 ".inst 0xc0062c44 // mova { z4.d-z7.d }, za.d[x9, #2]\n"
434 ".inst 0xa062c304 // st1w { z4.s-z7.s }, p8, [x24, #0x8, MUL VL]\n"
435 "addvl x24, x24, #12\n"
436 "27:" // Width 3: Output done
437 "b 36f\n"
// Width 4: ZA groups #0..#3, 16 vectors in total (the "16VL" in the kernel
// name). The tail predicate covers N - 3 * x28. When bias is present its
// cursor advances 16 vectors per pass. This path loops back to the column
// loop while blocks remain.
438 "28:" // Width 4
439 "mov x19, #0x3\n"
440 "mov x22, %x[A_ptr]\n"
441 "lsl x21, %x[K], #0x2\n"
442 "msub x19, x28, x19, %x[N]\n"
443 "mov x20, %x[K]\n"
444 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
445 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
446 "cbz x23, 29f\n"
447 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
448 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
449 ".inst 0xa041c6f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
450 ".inst 0xc0042e01 // mova za.d[x9, #1], { z16.d-z19.d }\n"
451 ".inst 0xa042c6fc // ld1w { z28.s-z31.s }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
452 ".inst 0xc0042f82 // mova za.d[x9, #2], { z28.d-z31.d }\n"
453 ".inst 0xa043c6f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x23, #0xc, MUL VL]\n"
454 ".inst 0xc0042e03 // mova za.d[x9, #3], { z16.d-z19.d }\n"
455 "addvl x23, x23, #16\n"
456 "b 30f\n"
457 "29:" // Width 4: no bias
458 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
459 "30:" // Width 4: setup done
460 "cmp x20, #0x8\n"
461 "ble 32f\n"
462 "31:" // Width 4: Multiply loop: Main loop head
463 "whilelt p1.s, XZR, x20\n"
464 "whilelt p0.s, x26, x20\n"
465 "ld1rqw { z0.s }, p1/Z, [x22]\n"
466 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
467 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
468 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
469 "uzp1 z0.h, z0.h, z0.h\n"
470 "sub x20, x20, #0x8\n"
471 "uzp1 z11.h, z11.h, z11.h\n"
472 "trn1 z0.d, z0.d, z11.d\n"
473 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
474 "cmp x20, #0x8\n"
475 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
476 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
477 "add x22, x22, #0x20\n"
478 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
479 ".inst 0xc150b119 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[0]\n"
480 ".inst 0xa043a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
481 ".inst 0xc150b21a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[0]\n"
482 "addvl x25, x25, #16\n"
483 ".inst 0xc150b39b // bfdot za.s[x9, 3], { z28.h-z31.h }, z0.h[0]\n"
484 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
485 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
486 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
487 ".inst 0xa042a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
488 ".inst 0xc150b499 // bfdot za.s[x9, 1], { z4.h-z7.h }, z0.h[1]\n"
489 ".inst 0xa043a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
490 ".inst 0xc150b59a // bfdot za.s[x9, 2], { z12.h-z15.h }, z0.h[1]\n"
491 "addvl x25, x25, #16\n"
492 ".inst 0xc150b79b // bfdot za.s[x9, 3], { z28.h-z31.h }, z0.h[1]\n"
493 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
494 ".inst 0xa041a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
495 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
496 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
497 ".inst 0xc150ba99 // bfdot za.s[x9, 1], { z20.h-z23.h }, z0.h[2]\n"
498 ".inst 0xa043a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
499 ".inst 0xc150ba1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[2]\n"
500 "addvl x25, x25, #16\n"
501 ".inst 0xc150b99b // bfdot za.s[x9, 3], { z12.h-z15.h }, z0.h[2]\n"
502 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
503 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
504 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
505 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
506 ".inst 0xc150bd19 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[3]\n"
507 ".inst 0xa043a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
508 ".inst 0xc150be1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[3]\n"
509 "addvl x25, x25, #16\n"
510 ".inst 0xc150bf9b // bfdot za.s[x9, 3], { z28.h-z31.h }, z0.h[3]\n"
511 "bgt 31b\n"
512 "32:" // Width 4: Multiply loop: Single iteration only
513 "whilelt p1.s, XZR, x20\n"
514 "whilelt p0.s, x26, x20\n"
515 "ld1rqw { z0.s }, p1/Z, [x22]\n"
516 ".inst 0x658aa800 // bfcvt z0.h, p2/M, z0.s\n"
517 "ld1rqw { z11.s }, p0/Z, [x22, #16]\n"
518 ".inst 0x658aa96b // bfcvt z11.h, p2/M, z11.s\n"
519 "uzp1 z0.h, z0.h, z0.h\n"
520 "subs x20, x20, #0x2\n"
521 "uzp1 z11.h, z11.h, z11.h\n"
522 "trn1 z0.d, z0.d, z11.d\n"
523 ".inst 0xa040a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25]\n"
524 "add x22, x22, #0x20\n"
525 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
526 ".inst 0xc150b098 // bfdot za.s[x9, 0], { z4.h-z7.h }, z0.h[0]\n"
527 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
528 ".inst 0xc150b119 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[0]\n"
529 ".inst 0xa043a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
530 ".inst 0xc150b21a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[0]\n"
531 "addvl x25, x25, #16\n"
532 ".inst 0xc150b39b // bfdot za.s[x9, 3], { z28.h-z31.h }, z0.h[0]\n"
533 "ble 33f\n"
534 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
535 "subs x20, x20, #0x2\n"
536 ".inst 0xc150b618 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[1]\n"
537 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
538 ".inst 0xc150b499 // bfdot za.s[x9, 1], { z4.h-z7.h }, z0.h[1]\n"
539 ".inst 0xa042a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
540 ".inst 0xc150b59a // bfdot za.s[x9, 2], { z12.h-z15.h }, z0.h[1]\n"
541 ".inst 0xa043a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
542 ".inst 0xc150b79b // bfdot za.s[x9, 3], { z28.h-z31.h }, z0.h[1]\n"
543 "addvl x25, x25, #16\n"
544 "ble 33f\n"
545 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
546 "subs x20, x20, #0x2\n"
547 ".inst 0xc150ba18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z0.h[2]\n"
548 ".inst 0xa041a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
549 ".inst 0xc150ba99 // bfdot za.s[x9, 1], { z20.h-z23.h }, z0.h[2]\n"
550 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
551 ".inst 0xc150ba1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[2]\n"
552 ".inst 0xa043a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
553 ".inst 0xc150b99b // bfdot za.s[x9, 3], { z12.h-z15.h }, z0.h[2]\n"
554 "addvl x25, x25, #16\n"
555 "ble 33f\n"
556 ".inst 0xa040a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25]\n"
557 ".inst 0xc150bf98 // bfdot za.s[x9, 0], { z28.h-z31.h }, z0.h[3]\n"
558 ".inst 0xa041a729 // ldnt1h { z8.h-z11.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
559 ".inst 0xc150bd19 // bfdot za.s[x9, 1], { z8.h-z11.h }, z0.h[3]\n"
560 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
561 ".inst 0xc150be1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z0.h[3]\n"
562 ".inst 0xa043a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
563 ".inst 0xc150bf9b // bfdot za.s[x9, 3], { z28.h-z31.h }, z0.h[3]\n"
564 "addvl x25, x25, #16\n"
// Epilogue: groups #0-#2 are stored full, group #3 under the tail predicate p8.
565 "33:" // Width 4: Multiply loop: multiply skip
566 "tbz %x[flags], #1, 34f\n"
567 "add x20, %x[args_ptr], %[offset_min]\n"
568 "add x19, %x[args_ptr], %[offset_max]\n"
569 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
570 "ld1rw { z29.s }, p2/Z, [x20]\n"
571 ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n"
572 "ld1rw { z18.s }, p2/Z, [x19]\n"
573 ".inst 0xc1b2cba8 // fclamp { z8.s-z11.s }, z29.s, z18.s\n"
574 ".inst 0xc0062c44 // mova { z4.d-z7.d }, za.d[x9, #2]\n"
575 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
576 ".inst 0xc1b2cbac // fclamp { z12.s-z15.s }, z29.s, z18.s\n"
577 ".inst 0xc0062c60 // mova { z0.d-z3.d }, za.d[x9, #3]\n"
578 ".inst 0xa061c70c // st1w { z12.s-z15.s }, pn9.b, [x24, #0x4, MUL VL]\n"
579 ".inst 0xc1b2cba4 // fclamp { z4.s-z7.s }, z29.s, z18.s\n"
580 ".inst 0xa062c704 // st1w { z4.s-z7.s }, pn9.b, [x24, #0x8, MUL VL]\n"
581 ".inst 0xc1b2cba0 // fclamp { z0.s-z3.s }, z29.s, z18.s\n"
582 ".inst 0xa063c300 // st1w { z0.s-z3.s }, p8, [x24, #0xc, MUL VL]\n"
583 "addvl x24, x24, #16\n"
584 "b 35f\n"
585 "34:" // Width 4: No activation
586 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
587 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
588 ".inst 0xc0062c2c // mova { z12.d-z15.d }, za.d[x9, #1]\n"
589 ".inst 0xa061c70c // st1w { z12.s-z15.s }, pn9.b, [x24, #0x4, MUL VL]\n"
590 ".inst 0xc0062c44 // mova { z4.d-z7.d }, za.d[x9, #2]\n"
591 ".inst 0xa062c704 // st1w { z4.s-z7.s }, pn9.b, [x24, #0x8, MUL VL]\n"
592 ".inst 0xc0062c60 // mova { z0.d-z3.d }, za.d[x9, #3]\n"
593 ".inst 0xa063c300 // st1w { z0.s-z3.s }, p8, [x24, #0xc, MUL VL]\n"
594 "addvl x24, x24, #16\n"
// Consume 4 column blocks: x27 -= 4 and N -= 4 * x28. Repeat while blocks
// remain.
595 "35:" // Width 4: Output done
596 "subs x27, x27, #0x4\n"
597 "sub %x[N], %x[N], x28, LSL #2\n"
598 "bgt 4b\n"
// Common exit: leave streaming mode and release ZA (see the SMSTART note
// above about which form this encoding is).
599 "36:" // Exit
600 ".inst 0xd503467f // SMSTOP\n"
601 "ptrue p2.b\n"
// N is a read-write operand because the width-4 path decrements it. All
// scratch GPRs, Z and P registers used above are listed as clobbers.
602 : [N] "+&r" (N)
603 : [A_ptr] "r" (A_ptr), [B_ptr] "r" (B_ptr), [K] "r" (K), [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [output_ptr] "r" (output_ptr)
604 : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
605 );
606 }
607
608 } // namespace arm_gemm
609
610 #endif // ARM_COMPUTE_ENABLE_SME2
611 #endif
612