1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24 #ifdef __ARM_FEATURE_SVE
25 #ifdef ARM_COMPUTE_ENABLE_SME2
26
27 #include "arm_gemm.hpp"
28 #include "../../utils.hpp"
29 #include "../../bfloat.hpp"
30
31 #include <cassert>
32 #include <limits>
33
34 namespace arm_gemm {
35
// SME2 GEMV kernel: bfloat16 inputs (A_ptr, B_ptr) accumulated into ZA with
// BFDOT and written out as float.
//
// Parameters, as used by the code below:
//   A_ptr      - bf16 input vector; read 16 bytes (8 bf16 values) at a time,
//                predicated on the remaining K so the tail load is zero-filled.
//   B_ptr      - bf16 right-hand-side data, streamed sequentially with
//                multi-vector non-temporal loads.
//                NOTE(review): the pre-arranged layout B must have is not
//                visible in this function - confirm against the caller.
//   output_ptr - float destination; the last store of each pass is predicated
//                on the remaining column count.
//   N          - number of output columns (modified by the asm as it iterates).
//   K          - depth of each dot product.
//   bias       - optional float bias; when null the accumulators start at zero.
//   act        - activation; ReLU / BoundedReLU enable a clamp before the
//                store, every other type stores the raw accumulators.
//   (bool)     - unnamed parameter, unused in this function.
sme2_gemv_bf16fp32_dot_16VL(const bfloat16 * A_ptr,const bfloat16 * B_ptr,float * output_ptr,size_t N,size_t K,const float * bias,Activation act,bool)36 void sme2_gemv_bf16fp32_dot_16VL (
37 const bfloat16 *A_ptr, const bfloat16 *B_ptr, float *output_ptr,
38 size_t N, size_t K,
39 const float *bias, Activation act, bool
40 )
41 {
// Argument block handed to the asm by address (args_ptr). Within this function
// the asm only reads minval/maxval (through offset_min / offset_max); B_ptr is
// assigned below, and output_offset / input_initial_col are not referenced
// anywhere else in the visible code.
42 struct KernelArgs {
// Defaults of +/- infinity make the clamp bounds a no-op unless overridden.
43 float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
44 float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
45 const bfloat16 *B_ptr = {};
46 size_t output_offset = {};
47 unsigned int input_initial_col = {};
48 } ka;
49
// Bit 1 (0x2) tells the asm to clamp results to [minval, maxval] before storing
// (tested with "tbz %x[flags], #1" in every output path).
50 unsigned long flags=0;
51 ka.B_ptr = B_ptr;
52 switch(act.type) {
53 default:
54 case Activation::Type::None:
55 break;
56 case Activation::Type::BoundedReLU:
// Upper bound comes from the activation parameter; the lower bound of 0 is
// set by the ReLU case that this deliberately falls into.
57 ka.maxval = static_cast<float>(act.param1);
58 /* fall through */
59 case Activation::Type::ReLU:
60 ka.minval = 0;
61 flags |= 0x2;
62 break;
63 }
64 __asm__ __volatile__(
65 "ptrue p1.b\n"
66 ".inst 0xd503477f // SMSTART ZA\n"
// x27 = number of 32-bit lanes in four vectors; x26 = ceil(N / x27), i.e. the
// number of four-vector-wide column groups still to be produced.
67 "cntw x27, ALL, MUL #4\n"
68 "add x26, %x[N], x27\n"
69 "sub x26, x26, #0x1\n"
70 "udiv x26, x26, x27\n"
// x21 = (x26 rounded up to a multiple of 4) * x27 * K, doubled at the lsl
// below; it is then folded into the operand of the RHS range prefetch.
71 "add x21, x26, #0x3\n"
72 "and x21, x21, #0xfffffffffffffffc\n"
73 "mul x21, x21, x27\n"
74 "mul x21, x21, %x[K]\n"
75 "mov x9, #0x0\n"
76 "mov x25, %x[B_ptr]\n"
77 "mov x24, %x[output_ptr]\n"
78 "ptrue p1.b\n"
79 ".inst 0x25207811 // ptrue pn9.b\n"
80 "lsl x21, x21, #0x1\n"
81 "mov x20, #0x1\n"
// While x21 >= 0x200000, halve it and double x20; if x21 is odd at that point
// the prefetch is skipped entirely (branch to 3).
82 "1:" // RHS size check loop
83 "cmp x21, #0x200000\n"
84 "blt 2f\n"
85 "tbnz x21, #0, 3f\n"
86 "lsr x21, x21, #0x1\n"
87 "lsl x20, x20, #0x1\n"
88 "b 1b\n"
89 "2:" // RHS do prefetch
90 "lsl x19, x21, #0x26\n"
91 "sub x20, x20, #0x1\n"
92 "lsl x20, x20, #0x16\n"
93 "orr x21, x21, x19\n"
94 "orr x21, x21, x20\n"
95 ".inst 0xf8b54b3a // rprfm pldonce, x21, [x25]\n"
96 "3:" // RHS prefetch exit
97 "mov x23, %x[bias]\n"
// Dispatch on x26 (remaining column groups): >= 4 -> Width 4 (label 28),
// 3 -> Width 3 (label 20), 2 -> Width 2 (label 12), otherwise Width 1 below.
98 "4:" // Column loop
99 "cmp x26, #0x4\n"
100 "bge 28f\n"
101 "cmp x26, #0x2\n"
102 "bgt 20f\n"
103 "beq 12f\n"
104 "mov x22, %x[A_ptr]\n"
105 "lsl x21, %x[K], #0x1\n"
106 "mov x19, %x[N]\n"
107 "mov x20, %x[K]\n"
108 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
109 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
// Accumulators start from the bias when x23 (bias) is non-null, else from zero.
110 "cbz x23, 5f\n"
111 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
112 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
113 "b 6f\n"
114 "5:" // Width 1: no bias
115 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
116 "6:" // Width 1: setup done
117 "cmp x20, #0x8\n"
118 "ble 8f\n"
// Main loop: each iteration consumes 8 bf16 values of A (one 16-byte ld1rqh)
// through four indexed bfdot steps, and runs while more than 8 values remain.
119 "7:" // Width 1: Multiply loop: Main loop head
120 "whilelt p0.h, XZR, x20\n"
121 "ld1rqh { z10.h }, p0/Z, [x22]\n"
122 "sub x20, x20, #0x8\n"
123 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
124 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
125 "addvl x25, x25, #16\n"
126 "cmp x20, #0x8\n"
127 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
128 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
129 "addvl x25, x25, #16\n"
130 "add x22, x22, #0x10\n"
131 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
132 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
133 "addvl x25, x25, #16\n"
134 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
135 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
136 "addvl x25, x25, #16\n"
137 "bgt 7b\n"
// Tail: at most 8 values remain; x20 drops by 2 per bfdot step and the
// remaining steps are skipped as soon as it reaches zero or below.
138 "8:" // Width 1: Multiply loop: Single iteration only
139 "whilelt p0.h, XZR, x20\n"
140 "ld1rqh { z10.h }, p0/Z, [x22]\n"
141 "subs x20, x20, #0x2\n"
142 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
143 "add x22, x22, #0x10\n"
144 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
145 "addvl x25, x25, #16\n"
146 "ble 9f\n"
147 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
148 "subs x20, x20, #0x2\n"
149 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
150 "addvl x25, x25, #16\n"
151 "ble 9f\n"
152 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
153 "subs x20, x20, #0x2\n"
154 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
155 "addvl x25, x25, #16\n"
156 "ble 9f\n"
157 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
158 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
159 "addvl x25, x25, #16\n"
// Output: when flags bit 1 is set, clamp with minval/maxval loaded from the
// argument block; the store uses p8, the predicate built from the remaining N.
160 "9:" // Width 1: Multiply loop: multiply skip
161 "tbz %x[flags], #1, 10f\n"
162 "add x20, %x[args_ptr], %[offset_min]\n"
163 "add x19, %x[args_ptr], %[offset_max]\n"
164 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
165 "ld1rw { z0.s }, p1/Z, [x20]\n"
166 "ld1rw { z6.s }, p1/Z, [x19]\n"
167 ".inst 0xc1a6c808 // fclamp { z8.s-z11.s }, z0.s, z6.s\n"
168 ".inst 0xa060c308 // st1w { z8.s-z11.s }, p8, [x24]\n"
169 "addvl x24, x24, #4\n"
170 "b 11f\n"
171 "10:" // Width 1: No activation
172 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
173 ".inst 0xa060c308 // st1w { z8.s-z11.s }, p8, [x24]\n"
174 "addvl x24, x24, #4\n"
175 "11:" // Width 1: Output done
176 "b 36f\n"
// Width 2: same structure as Width 1 with two accumulator groups; the first
// store is unpredicated (pn9) and only the last one uses p8.
177 "12:" // Width 2
178 "mov x22, %x[A_ptr]\n"
179 "lsl x21, %x[K], #0x1\n"
180 "sub x19, %x[N], x27\n"
181 "mov x20, %x[K]\n"
182 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
183 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
184 "cbz x23, 13f\n"
185 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
186 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
187 ".inst 0xa041c6e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
188 ".inst 0xc0042d01 // mova za.d[x9, #1], { z8.d-z11.d }\n"
189 "b 14f\n"
190 "13:" // Width 2: no bias
191 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
192 "14:" // Width 2: setup done
193 "cmp x20, #0x8\n"
194 "ble 16f\n"
195 "15:" // Width 2: Multiply loop: Main loop head
196 "whilelt p0.h, XZR, x20\n"
197 "ld1rqh { z10.h }, p0/Z, [x22]\n"
198 "sub x20, x20, #0x8\n"
199 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
200 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
201 "cmp x20, #0x8\n"
202 "add x22, x22, #0x10\n"
203 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
204 ".inst 0xc15ab099 // bfdot za.s[x9, 1], { z4.h-z7.h }, z10.h[0]\n"
205 "addvl x25, x25, #16\n"
206 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
207 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
208 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
209 ".inst 0xc15ab619 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[1]\n"
210 "addvl x25, x25, #16\n"
211 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
212 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
213 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
214 ".inst 0xc15aba19 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[2]\n"
215 "addvl x25, x25, #16\n"
216 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
217 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
218 ".inst 0xa041a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
219 ".inst 0xc15abf19 // bfdot za.s[x9, 1], { z24.h-z27.h }, z10.h[3]\n"
220 "addvl x25, x25, #16\n"
221 "bgt 15b\n"
222 "16:" // Width 2: Multiply loop: Single iteration only
223 "whilelt p0.h, XZR, x20\n"
224 "ld1rqh { z10.h }, p0/Z, [x22]\n"
225 "subs x20, x20, #0x2\n"
226 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
227 "add x22, x22, #0x10\n"
228 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
229 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
230 ".inst 0xc15ab099 // bfdot za.s[x9, 1], { z4.h-z7.h }, z10.h[0]\n"
231 "addvl x25, x25, #16\n"
232 "ble 17f\n"
233 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
234 "subs x20, x20, #0x2\n"
235 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
236 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
237 ".inst 0xc15ab619 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[1]\n"
238 "addvl x25, x25, #16\n"
239 "ble 17f\n"
240 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
241 "subs x20, x20, #0x2\n"
242 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
243 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
244 ".inst 0xc15aba19 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[2]\n"
245 "addvl x25, x25, #16\n"
246 "ble 17f\n"
247 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
248 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
249 ".inst 0xa041a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
250 ".inst 0xc15abf19 // bfdot za.s[x9, 1], { z24.h-z27.h }, z10.h[3]\n"
251 "addvl x25, x25, #16\n"
252 "17:" // Width 2: Multiply loop: multiply skip
253 "tbz %x[flags], #1, 18f\n"
254 "add x20, %x[args_ptr], %[offset_min]\n"
255 "add x19, %x[args_ptr], %[offset_max]\n"
256 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
257 "ld1rw { z0.s }, p1/Z, [x20]\n"
258 ".inst 0xc0062c34 // mova { z20.d-z23.d }, za.d[x9, #1]\n"
259 "ld1rw { z6.s }, p1/Z, [x19]\n"
260 ".inst 0xc1a6c808 // fclamp { z8.s-z11.s }, z0.s, z6.s\n"
261 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
262 ".inst 0xc1a6c814 // fclamp { z20.s-z23.s }, z0.s, z6.s\n"
263 ".inst 0xa061c314 // st1w { z20.s-z23.s }, p8, [x24, #0x4, MUL VL]\n"
264 "addvl x24, x24, #8\n"
265 "b 19f\n"
266 "18:" // Width 2: No activation
267 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
268 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
269 ".inst 0xc0062c34 // mova { z20.d-z23.d }, za.d[x9, #1]\n"
270 ".inst 0xa061c314 // st1w { z20.s-z23.s }, p8, [x24, #0x4, MUL VL]\n"
271 "addvl x24, x24, #8\n"
272 "19:" // Width 2: Output done
273 "b 36f\n"
// Width 3: three accumulator groups; p8 covers N minus two full groups.
274 "20:" // Width 3
275 "mov x19, #0x2\n"
276 "mov x22, %x[A_ptr]\n"
277 "lsl x21, %x[K], #0x1\n"
278 "msub x19, x27, x19, %x[N]\n"
279 "mov x20, %x[K]\n"
280 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
281 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
282 "cbz x23, 21f\n"
283 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
284 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
285 ".inst 0xa041c6e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
286 ".inst 0xc0042d01 // mova za.d[x9, #1], { z8.d-z11.d }\n"
287 ".inst 0xa042c6e4 // ld1w { z4.s-z7.s }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
288 ".inst 0xc0042c82 // mova za.d[x9, #2], { z4.d-z7.d }\n"
289 "b 22f\n"
290 "21:" // Width 3: no bias
291 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
292 "22:" // Width 3: setup done
293 "cmp x20, #0x8\n"
294 "ble 24f\n"
295 "23:" // Width 3: Multiply loop: Main loop head
296 "whilelt p0.h, XZR, x20\n"
297 "ld1rqh { z10.h }, p0/Z, [x22]\n"
298 "sub x20, x20, #0x8\n"
299 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
300 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
301 "cmp x20, #0x8\n"
302 "add x22, x22, #0x10\n"
303 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
304 ".inst 0xc15ab099 // bfdot za.s[x9, 1], { z4.h-z7.h }, z10.h[0]\n"
305 ".inst 0xa042a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
306 ".inst 0xc15ab29a // bfdot za.s[x9, 2], { z20.h-z23.h }, z10.h[0]\n"
307 "addvl x25, x25, #16\n"
308 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
309 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
310 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
311 ".inst 0xc15ab619 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[1]\n"
312 ".inst 0xa042a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
313 ".inst 0xc15ab71a // bfdot za.s[x9, 2], { z24.h-z27.h }, z10.h[1]\n"
314 "addvl x25, x25, #16\n"
315 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
316 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
317 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
318 ".inst 0xc15aba19 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[2]\n"
319 ".inst 0xa042a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
320 ".inst 0xc15abb9a // bfdot za.s[x9, 2], { z28.h-z31.h }, z10.h[2]\n"
321 "addvl x25, x25, #16\n"
322 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
323 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
324 ".inst 0xa041a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
325 ".inst 0xc15abf19 // bfdot za.s[x9, 1], { z24.h-z27.h }, z10.h[3]\n"
326 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
327 ".inst 0xc15abe1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z10.h[3]\n"
328 "addvl x25, x25, #16\n"
329 "bgt 23b\n"
330 "24:" // Width 3: Multiply loop: Single iteration only
331 "whilelt p0.h, XZR, x20\n"
332 "ld1rqh { z10.h }, p0/Z, [x22]\n"
333 "subs x20, x20, #0x2\n"
334 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
335 "add x22, x22, #0x10\n"
336 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
337 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
338 ".inst 0xc15ab099 // bfdot za.s[x9, 1], { z4.h-z7.h }, z10.h[0]\n"
339 ".inst 0xa042a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
340 ".inst 0xc15ab29a // bfdot za.s[x9, 2], { z20.h-z23.h }, z10.h[0]\n"
341 "addvl x25, x25, #16\n"
342 "ble 25f\n"
343 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
344 "subs x20, x20, #0x2\n"
345 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
346 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
347 ".inst 0xc15ab619 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[1]\n"
348 ".inst 0xa042a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
349 ".inst 0xc15ab71a // bfdot za.s[x9, 2], { z24.h-z27.h }, z10.h[1]\n"
350 "addvl x25, x25, #16\n"
351 "ble 25f\n"
352 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
353 "subs x20, x20, #0x2\n"
354 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
355 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
356 ".inst 0xc15aba19 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[2]\n"
357 ".inst 0xa042a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
358 ".inst 0xc15abb9a // bfdot za.s[x9, 2], { z28.h-z31.h }, z10.h[2]\n"
359 "addvl x25, x25, #16\n"
360 "ble 25f\n"
361 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
362 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
363 ".inst 0xa041a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
364 ".inst 0xc15abf19 // bfdot za.s[x9, 1], { z24.h-z27.h }, z10.h[3]\n"
365 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
366 ".inst 0xc15abe1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z10.h[3]\n"
367 "addvl x25, x25, #16\n"
368 "25:" // Width 3: Multiply loop: multiply skip
369 "tbz %x[flags], #1, 26f\n"
370 "add x20, %x[args_ptr], %[offset_min]\n"
371 "add x19, %x[args_ptr], %[offset_max]\n"
372 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
373 "ld1rw { z0.s }, p1/Z, [x20]\n"
374 ".inst 0xc0062c34 // mova { z20.d-z23.d }, za.d[x9, #1]\n"
375 "ld1rw { z6.s }, p1/Z, [x19]\n"
376 ".inst 0xc1a6c808 // fclamp { z8.s-z11.s }, z0.s, z6.s\n"
377 ".inst 0xc0062c50 // mova { z16.d-z19.d }, za.d[x9, #2]\n"
378 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
379 ".inst 0xc1a6c814 // fclamp { z20.s-z23.s }, z0.s, z6.s\n"
380 ".inst 0xa061c714 // st1w { z20.s-z23.s }, pn9.b, [x24, #0x4, MUL VL]\n"
381 ".inst 0xc1a6c810 // fclamp { z16.s-z19.s }, z0.s, z6.s\n"
382 ".inst 0xa062c310 // st1w { z16.s-z19.s }, p8, [x24, #0x8, MUL VL]\n"
383 "addvl x24, x24, #12\n"
384 "b 27f\n"
385 "26:" // Width 3: No activation
386 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
387 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
388 ".inst 0xc0062c34 // mova { z20.d-z23.d }, za.d[x9, #1]\n"
389 ".inst 0xa061c714 // st1w { z20.s-z23.s }, pn9.b, [x24, #0x4, MUL VL]\n"
390 ".inst 0xc0062c50 // mova { z16.d-z19.d }, za.d[x9, #2]\n"
391 ".inst 0xa062c310 // st1w { z16.s-z19.s }, p8, [x24, #0x8, MUL VL]\n"
392 "addvl x24, x24, #12\n"
393 "27:" // Width 3: Output done
394 "b 36f\n"
// Width 4: four accumulator groups; p8 covers N minus three full groups. This
// is the only path that advances the bias pointer (x23) and loops back.
395 "28:" // Width 4
396 "mov x19, #0x3\n"
397 "mov x22, %x[A_ptr]\n"
398 "lsl x21, %x[K], #0x1\n"
399 "msub x19, x27, x19, %x[N]\n"
400 "mov x20, %x[K]\n"
401 ".inst 0xf8b54ad8 // rprfm pldmany, x21, [x22]\n"
402 ".inst 0x25b367f0 // whilelt p8.s, XZR, x19, VLx4\n"
403 "cbz x23, 29f\n"
404 ".inst 0xa040c6e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x23]\n"
405 ".inst 0xc0042c00 // mova za.d[x9, #0], { z0.d-z3.d }\n"
406 ".inst 0xa041c6e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
407 ".inst 0xc0042d01 // mova za.d[x9, #1], { z8.d-z11.d }\n"
408 ".inst 0xa042c6e4 // ld1w { z4.s-z7.s }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
409 ".inst 0xc0042c82 // mova za.d[x9, #2], { z4.d-z7.d }\n"
410 ".inst 0xa043c6f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x23, #0xc, MUL VL]\n"
411 ".inst 0xc0042e03 // mova za.d[x9, #3], { z16.d-z19.d }\n"
412 "addvl x23, x23, #16\n"
413 "b 30f\n"
414 "29:" // Width 4: no bias
415 ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
416 "30:" // Width 4: setup done
417 "cmp x20, #0x8\n"
418 "ble 32f\n"
419 "31:" // Width 4: Multiply loop: Main loop head
420 "whilelt p0.h, XZR, x20\n"
421 "ld1rqh { z10.h }, p0/Z, [x22]\n"
422 "sub x20, x20, #0x8\n"
423 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
424 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
425 "cmp x20, #0x8\n"
426 "add x22, x22, #0x10\n"
427 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
428 ".inst 0xc15ab099 // bfdot za.s[x9, 1], { z4.h-z7.h }, z10.h[0]\n"
429 ".inst 0xa042a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
430 ".inst 0xc15ab29a // bfdot za.s[x9, 2], { z20.h-z23.h }, z10.h[0]\n"
431 ".inst 0xa043a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
432 ".inst 0xc15ab21b // bfdot za.s[x9, 3], { z16.h-z19.h }, z10.h[0]\n"
433 "addvl x25, x25, #16\n"
434 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
435 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
436 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
437 ".inst 0xc15ab619 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[1]\n"
438 ".inst 0xa042a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
439 ".inst 0xc15ab71a // bfdot za.s[x9, 2], { z24.h-z27.h }, z10.h[1]\n"
440 ".inst 0xa043a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
441 ".inst 0xc15ab61b // bfdot za.s[x9, 3], { z16.h-z19.h }, z10.h[1]\n"
442 "addvl x25, x25, #16\n"
443 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
444 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
445 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
446 ".inst 0xc15aba19 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[2]\n"
447 ".inst 0xa042a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
448 ".inst 0xc15abb9a // bfdot za.s[x9, 2], { z28.h-z31.h }, z10.h[2]\n"
449 ".inst 0xa043a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
450 ".inst 0xc15aba9b // bfdot za.s[x9, 3], { z20.h-z23.h }, z10.h[2]\n"
451 "addvl x25, x25, #16\n"
452 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
453 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
454 ".inst 0xa041a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
455 ".inst 0xc15abf19 // bfdot za.s[x9, 1], { z24.h-z27.h }, z10.h[3]\n"
456 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
457 ".inst 0xc15abe1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z10.h[3]\n"
458 ".inst 0xa043a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
459 ".inst 0xc15abe1b // bfdot za.s[x9, 3], { z16.h-z19.h }, z10.h[3]\n"
460 "addvl x25, x25, #16\n"
461 "bgt 31b\n"
462 "32:" // Width 4: Multiply loop: Single iteration only
463 "whilelt p0.h, XZR, x20\n"
464 "ld1rqh { z10.h }, p0/Z, [x22]\n"
465 "subs x20, x20, #0x2\n"
466 ".inst 0xa040a721 // ldnt1h { z0.h-z3.h }, pn9.b/Z, [x25]\n"
467 "add x22, x22, #0x10\n"
468 ".inst 0xc15ab018 // bfdot za.s[x9, 0], { z0.h-z3.h }, z10.h[0]\n"
469 ".inst 0xa041a725 // ldnt1h { z4.h-z7.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
470 ".inst 0xc15ab099 // bfdot za.s[x9, 1], { z4.h-z7.h }, z10.h[0]\n"
471 ".inst 0xa042a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
472 ".inst 0xc15ab29a // bfdot za.s[x9, 2], { z20.h-z23.h }, z10.h[0]\n"
473 ".inst 0xa043a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
474 ".inst 0xc15ab21b // bfdot za.s[x9, 3], { z16.h-z19.h }, z10.h[0]\n"
475 "addvl x25, x25, #16\n"
476 "ble 33f\n"
477 ".inst 0xa040a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25]\n"
478 "subs x20, x20, #0x2\n"
479 ".inst 0xc15ab718 // bfdot za.s[x9, 0], { z24.h-z27.h }, z10.h[1]\n"
480 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
481 ".inst 0xc15ab619 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[1]\n"
482 ".inst 0xa042a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
483 ".inst 0xc15ab71a // bfdot za.s[x9, 2], { z24.h-z27.h }, z10.h[1]\n"
484 ".inst 0xa043a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
485 ".inst 0xc15ab61b // bfdot za.s[x9, 3], { z16.h-z19.h }, z10.h[1]\n"
486 "addvl x25, x25, #16\n"
487 "ble 33f\n"
488 ".inst 0xa040a72d // ldnt1h { z12.h-z15.h }, pn9.b/Z, [x25]\n"
489 "subs x20, x20, #0x2\n"
490 ".inst 0xc15ab998 // bfdot za.s[x9, 0], { z12.h-z15.h }, z10.h[2]\n"
491 ".inst 0xa041a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
492 ".inst 0xc15aba19 // bfdot za.s[x9, 1], { z16.h-z19.h }, z10.h[2]\n"
493 ".inst 0xa042a73d // ldnt1h { z28.h-z31.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
494 ".inst 0xc15abb9a // bfdot za.s[x9, 2], { z28.h-z31.h }, z10.h[2]\n"
495 ".inst 0xa043a735 // ldnt1h { z20.h-z23.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
496 ".inst 0xc15aba9b // bfdot za.s[x9, 3], { z20.h-z23.h }, z10.h[2]\n"
497 "addvl x25, x25, #16\n"
498 "ble 33f\n"
499 ".inst 0xa040a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25]\n"
500 ".inst 0xc15abe18 // bfdot za.s[x9, 0], { z16.h-z19.h }, z10.h[3]\n"
501 ".inst 0xa041a739 // ldnt1h { z24.h-z27.h }, pn9.b/Z, [x25, #0x4, MUL VL]\n"
502 ".inst 0xc15abf19 // bfdot za.s[x9, 1], { z24.h-z27.h }, z10.h[3]\n"
503 ".inst 0xa042a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0x8, MUL VL]\n"
504 ".inst 0xc15abe1a // bfdot za.s[x9, 2], { z16.h-z19.h }, z10.h[3]\n"
505 ".inst 0xa043a731 // ldnt1h { z16.h-z19.h }, pn9.b/Z, [x25, #0xc, MUL VL]\n"
506 ".inst 0xc15abe1b // bfdot za.s[x9, 3], { z16.h-z19.h }, z10.h[3]\n"
507 "addvl x25, x25, #16\n"
508 "33:" // Width 4: Multiply loop: multiply skip
509 "tbz %x[flags], #1, 34f\n"
510 "add x20, %x[args_ptr], %[offset_min]\n"
511 "add x19, %x[args_ptr], %[offset_max]\n"
512 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
513 "ld1rw { z0.s }, p1/Z, [x20]\n"
514 ".inst 0xc0062c34 // mova { z20.d-z23.d }, za.d[x9, #1]\n"
515 "ld1rw { z6.s }, p1/Z, [x19]\n"
516 ".inst 0xc1a6c808 // fclamp { z8.s-z11.s }, z0.s, z6.s\n"
517 ".inst 0xc0062c50 // mova { z16.d-z19.d }, za.d[x9, #2]\n"
518 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
519 ".inst 0xc1a6c814 // fclamp { z20.s-z23.s }, z0.s, z6.s\n"
520 ".inst 0xc0062c78 // mova { z24.d-z27.d }, za.d[x9, #3]\n"
521 ".inst 0xa061c714 // st1w { z20.s-z23.s }, pn9.b, [x24, #0x4, MUL VL]\n"
522 ".inst 0xc1a6c810 // fclamp { z16.s-z19.s }, z0.s, z6.s\n"
523 ".inst 0xa062c710 // st1w { z16.s-z19.s }, pn9.b, [x24, #0x8, MUL VL]\n"
524 ".inst 0xc1a6c818 // fclamp { z24.s-z27.s }, z0.s, z6.s\n"
525 ".inst 0xa063c318 // st1w { z24.s-z27.s }, p8, [x24, #0xc, MUL VL]\n"
526 "addvl x24, x24, #16\n"
527 "b 35f\n"
528 "34:" // Width 4: No activation
529 ".inst 0xc0062c08 // mova { z8.d-z11.d }, za.d[x9, #0]\n"
530 ".inst 0xa060c708 // st1w { z8.s-z11.s }, pn9.b, [x24]\n"
531 ".inst 0xc0062c34 // mova { z20.d-z23.d }, za.d[x9, #1]\n"
532 ".inst 0xa061c714 // st1w { z20.s-z23.s }, pn9.b, [x24, #0x4, MUL VL]\n"
533 ".inst 0xc0062c50 // mova { z16.d-z19.d }, za.d[x9, #2]\n"
534 ".inst 0xa062c710 // st1w { z16.s-z19.s }, pn9.b, [x24, #0x8, MUL VL]\n"
535 ".inst 0xc0062c78 // mova { z24.d-z27.d }, za.d[x9, #3]\n"
536 ".inst 0xa063c318 // st1w { z24.s-z27.s }, p8, [x24, #0xc, MUL VL]\n"
537 "addvl x24, x24, #16\n"
// Four column groups done: drop them from x26, shrink N by 4 * x27 columns,
// and return to the dispatch at label 4 while any groups remain.
538 "35:" // Width 4: Output done
539 "subs x26, x26, #0x4\n"
540 "sub %x[N], %x[N], x27, LSL #2\n"
541 "bgt 4b\n"
542 "36:" // Exit
543 ".inst 0xd503467f // SMSTOP\n"
544 "ptrue p1.b\n"
// N is a read-write early-clobber operand because the asm decrements it in the
// Width 4 loop; everything else is input-only. offset_min / offset_max are
// immediate byte offsets of minval / maxval inside KernelArgs.
545 : [N] "+&r" (N)
546 : [A_ptr] "r" (A_ptr), [B_ptr] "r" (B_ptr), [K] "r" (K), [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [output_ptr] "r" (output_ptr)
547 : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
548 );
549 }
550
551 } // namespace arm_gemm
552
553 #endif // ARM_COMPUTE_ENABLE_SME2
554 #endif
555