/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "ruy/kernel_arm.h"
17 #include "ruy/opt_set.h"
18 #include "ruy/platform.h"
19 #include "ruy/profiler/instrumentation.h"
20
21 namespace ruy {
22
23 #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
24
25 #define RUY_ASM_LABEL_STORE_UINT8 91
26 #define RUY_ASM_LABEL_STORE_INT8 92
27 #define RUY_ASM_LABEL_STORE_INT16 93
28 #define RUY_ASM_LABEL_STORE_INT32 94
29 #define RUY_ASM_LABEL_AFTER_STORE 99
30
31 #define RUY_OFFSET_LHS_BASE_PTR 0
32 #define RUY_OFFSET_RHS_BASE_PTR 4
33 #define RUY_OFFSET_DST_BASE_PTR 8
34 #define RUY_OFFSET_BIAS 12
35 #define RUY_OFFSET_START_ROW 16
36 #define RUY_OFFSET_START_COL 20
37 #define RUY_OFFSET_LAST_ROW 24
38 #define RUY_OFFSET_LAST_COL 28
39 #define RUY_OFFSET_DST_ROWS 32
40 #define RUY_OFFSET_DST_COLS 36
41 #define RUY_OFFSET_LHS_STRIDE 40
42 #define RUY_OFFSET_RHS_STRIDE 44
43 #define RUY_OFFSET_DST_STRIDE 48
44 #define RUY_OFFSET_DEPTH 52
45 #define RUY_OFFSET_CLAMP_MIN 56
46 #define RUY_OFFSET_CLAMP_MAX 60
47 #define RUY_OFFSET_FLAGS 64
48
49 #define RUY_STACK_OFFSET_SIZE 96
50 #define RUY_STACK_OFFSET_DST_COL_PTR 0
51 #define RUY_STACK_OFFSET_DST_PTR 16
52 #define RUY_STACK_OFFSET_ROW 32
53 #define RUY_STACK_OFFSET_COL 48
54 #define RUY_STACK_OFFSET_LHS_COL_PTR 64
55 #define RUY_STACK_OFFSET_RHS_COL_PTR 80
56
57 template <typename Params>
CheckOffsetsInKernelParamsFloat32(const Params &)58 void CheckOffsetsInKernelParamsFloat32(const Params&) {
59 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
60 static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
61 static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
62 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
63 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
64 static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
65 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
66 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
67 static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
68 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
69 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
70 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
71 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
72 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
73 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
74 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
75 }
76
77 // Float kernel for ARM32 out-of-order cores.
78 // Just like Float 64 version, except accumulate in to 8x4 block to only
79 // use 16 128-bit NEON registers. This is a "first pass" kernel and not
80 // tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9.
KernelFloat32Neon(const KernelParamsFloat<8,4> & params)81 void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) {
82 CheckOffsetsInKernelParamsFloat32(params);
83 profiler::ScopeLabel label("Kernel (kNeon)");
84
85 const float* lhs_ptr = params.lhs_base_ptr;
86 const float* rhs_ptr = params.rhs_base_ptr;
87 // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
88 // each composed of two 64-bit "d" registers. The asm kernel below has the
89 // following NEON register allocation:
90 // Registers q3 -- q10 are accumulators. During accumulation,
91 // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
92 // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block
93 // of RHS, like this:
94
95 // Register layout in "q" registers:
96 // RHS 1x4 block
97 // /--------------------------|
98 // |q2.s[0] ... q2.s[3] |
99 // \--------------------------/
100 // LHS 8x1 block
101 // /---------------------\ /--------------------------|
102 // | q0.s[0] | | q3.s[0] ... q9.s[0] |
103 // | ... | | ... ... |
104 // | q0.s[3] | | q3.s[3] q9.s[3] |
105 // | q1.s[0] | | q4.s[0] q10.s[0] |
106 // | ... | | ... ... ... |
107 // | q1.s[3] | | q4.s[3] .. q10.s[3] |
108 // \---------------------/ \--------------------------/
109 // accumulators 8x4 block
110 // q11, q14, q15 currently unused. q12 and q13 are used to load
111 // parameters used for the post-accumulation part of the kernel.
112 // For completeness, here is the register layout in "d" registers:
113 // RHS 1x4 block
114 // /--------------------------|
115 // |d4[0] ... d5[1] |
116 // \--------------------------/
117 // LHS 8x1 block
118 // /---------------------\ /--------------------------|
119 // | d0[0] | | d6[0] ... d18[0] |
120 // | ... | | ... ... |
121 // | d1[1] | | d7[1] d19[1] |
122 // | d2[0] | | d8[0] d20[0] |
123 // | ... | | ... ... ... |
124 // | d3[1] | | d9[1] ... d21[1] |
125 // \---------------------/ \--------------------------/
126 // accumulators 8x4 block
127 asm volatile(
128 #define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
129
130 // clang-format off
131
132 // Load the first 32 bytes of LHS and RHS data.
133 // Load q0, q1
134 "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
135 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
136 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
137 // Load q2
138 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
139 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
140
141 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
142
143 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
144 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
145
146 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
147 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
148
149 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
150 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
151
152 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
153 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
154
155 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
156 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
157
158 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
159 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
160 // Clear accumulators.
161 RUY_MAKE_ZERO(q3)
162 RUY_MAKE_ZERO(q4)
163 RUY_MAKE_ZERO(q5)
164 RUY_MAKE_ZERO(q6)
165 RUY_MAKE_ZERO(q7)
166 RUY_MAKE_ZERO(q8)
167 RUY_MAKE_ZERO(q9)
168 RUY_MAKE_ZERO(q10)
169
170 // r1 is the number of levels of depth that we have already loaded
171 // LHS and RHS data for. Corresponding to the initial ld1 instructions
172 // above, this is currently 1.
173 "mov r1, #1\n"
174
175 // Main loop of the whole GEMM, over rows and columns of the
176 // destination matrix.
177 "1:\n"
178
179 // Accumulation loop
180 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
181 "cmp r1, r2\n"
182 "beq 79f\n"
183
184 "2:\n"
185
186 "vmla.f32 q3, q0, d4[0]\n"
187 "vmla.f32 q5, q0, d4[1]\n"
188 "vmla.f32 q7, q0, d5[0]\n"
189 "vmla.f32 q9, q0, d5[1]\n"
190 "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
191
192 "vmla.f32 q4, q1, d4[0]\n"
193 "vmla.f32 q6, q1, d4[1]\n"
194 "vmla.f32 q8, q1, d5[0]\n"
195 "vmla.f32 q10, q1, d5[1]\n"
196 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
197 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
198 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
199 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
200
201 "add r1, r1, #1\n"
202 "cmp r1, r2\n"
203
204 "blt 2b\n"
205
206 "79:\n"
207
208 // End of the inner loop on depth. Now perform the remaining
209 // multiply-adds of the last level of depth, for which the LHS
210 // and RHS data is already loaded.
211
212 "vmla.f32 q3, q0, d4[0]\n"
213 "vmla.f32 q5, q0, d4[1]\n"
214 "vmla.f32 q7, q0, d5[0]\n"
215 "vmla.f32 q9, q0, d5[1]\n"
216
217 "vmla.f32 q4, q1, d4[0]\n"
218 "vmla.f32 q6, q1, d4[1]\n"
219 "vmla.f32 q8, q1, d5[0]\n"
220 "vmla.f32 q10, q1, d5[1]\n"
221
222 // End of accumulation. The registers q3 -- q10 contain the final
223 // float32 accumulator values of the current 8x8 destination block.
224 // We now have to compute the final values from these accumulators
225 // and advance to the next 8x8 block. We intertwine
226 // these two aspects whenever possible for optimal pipelining, both
227 // at the data flow level (prefetch data for next block as early as
228 // possible) and instruction pipelining level (some of the next-block
229 // work can dual-issue with some of the final work on the current
230 // block).
231
232 // Logic to advance to the next block in preparation for the next
233 // iteration of the main loop. For now, we only want to compute
234 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
235 // not yet ready to update the values of row and col, as we still need
236 // the current values for the rest of the work on the current block.
237
238 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
239 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
240 "cmp r1, r3\n" // Have we finished the last row?
241
242 "bge 4f\n" // If finished last row, go to 4
243 // Not finished last row: then advance to next row.
244 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
245 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
246 "add r4, r4, r1, lsl #3\n"
247 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
248 "b 5f\n"
249 "4:\n" // Finished last row...
250 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
251 // Go back to first row
252 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
253 // Now we need to advance to the next column. If we already
254 // finished the last column, then in principle we are done, however
255 // we can't just return here, as we need to allow the end work of the
256 // current block to complete. The good news is that at this point it
257 // doesn't matter what data we load for the next column, since
258 // we will exit from the main loop below before actually storing
259 // anything computed from that data.
260 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
261 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
262 "cmp r8, r4\n" // Have we finished the last column?
263 "bge 5f\n" // If yes, just carry on without updating the column pointer.
264 // Not finished last column: then advance to next column.
265 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
266 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
267 "add r10, r10, r1, lsl #2\n"
268 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
269 "5:\n"
270
271 // Set the LHS and RHS data pointers to the start of the columns just
272 // computed.
273 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
274 "mov %[lhs_ptr], r4\n"
275 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
276 "mov %[rhs_ptr], r5\n"
277
278 // Load some parameters needed for the end work on current block.
279 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
280 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
281
282 // Let r8 be stack offset of the row or column variable, whichever
283 // is the channel index.
284 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
285 "ite eq\n"
286 "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
287 "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
288 // Let r8 be the channel index.
289 "ldr r8, [sp, r8]\n"
290 // Compute the bias pointer, by conditionally using the channel index
291 // (r8) as offset into bias buffer (r1).
292 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
293 "it ne\n"
294 "addne r1, r1, r8, lsl #2\n"
295
296 // Load 4 bias values. When the channel dimension is rows, we will load
297 // another 4 bias values just before performing the bias addition below,
298 // as this kernel has a 8x4 rectangular shape.
299 "vld1.32 {d24, d25}, [r1]!\n"
300
301 // Now that we know what LHS and RHS data the next iteration of the
302 // main loop will need to load, we start loading the first 32 bytes of
303 // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore
304 // in the rest of the work on the current block.
305 // Load q0, q1
306 "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
307 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
308 // Load q2
309 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
310 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
311
312 // Perform the bias-addition.
313 // Jump based on channel dimension.
314 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
315 "bne 6f\n"
316 // Case where channels are rows.
317 // Load the remaining 4 bias values, since we're on the width-8 side
318 // of this 8x4 kernel.
319 "vld1.32 {d26, d27}, [r1]\n"
320 "vadd.f32 q3, q3, q12\n"
321 "vadd.f32 q5, q5, q12\n"
322 "vadd.f32 q7, q7, q12\n"
323 "vadd.f32 q9, q9, q12\n"
324 "vadd.f32 q4, q4, q13\n"
325 "vadd.f32 q6, q6, q13\n"
326 "vadd.f32 q8, q8, q13\n"
327 "vadd.f32 q10, q10, q13\n"
328 "b 7f\n"
329
330 "6:\n"
331 // Case where channels are columns.
332 "vdup.32 q11, d24[0]\n"
333 "vdup.32 q13, d24[1]\n"
334 "vdup.32 q14, d25[0]\n"
335 "vdup.32 q15, d25[1]\n"
336 "vadd.f32 q3, q3, q11\n"
337 "vadd.f32 q4, q4, q11\n"
338 "vadd.f32 q5, q5, q13\n"
339 "vadd.f32 q6, q6, q13\n"
340 "vadd.f32 q7, q7, q14\n"
341 "vadd.f32 q8, q8, q14\n"
342 "vadd.f32 q9, q9, q15\n"
343 "vadd.f32 q10, q10, q15\n"
344 "7:\n"
345
346 // Load the clamp_min, clamp_max bounds
347 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
348 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
349 "vdup.32 q12, r2\n" // clamp_min
350 "vdup.32 q13, r3\n" // clamp_max
351
352 // Apply the clamp_min bound
353 "vmax.f32 q3, q3, q12\n"
354 "vmax.f32 q4, q4, q12\n"
355 "vmax.f32 q5, q5, q12\n"
356 "vmax.f32 q6, q6, q12\n"
357 "vmax.f32 q7, q7, q12\n"
358 "vmax.f32 q8, q8, q12\n"
359 "vmax.f32 q9, q9, q12\n"
360 "vmax.f32 q10, q10, q12\n"
361
362 // Apply the clamp_max bound
363 "vmin.f32 q3, q3, q13\n"
364 "vmin.f32 q4, q4, q13\n"
365 "vmin.f32 q5, q5, q13\n"
366 "vmin.f32 q6, q6, q13\n"
367 "vmin.f32 q7, q7, q13\n"
368 "vmin.f32 q8, q8, q13\n"
369 "vmin.f32 q9, q9, q13\n"
370 "vmin.f32 q10, q10, q13\n"
371
372 // Compute how much of the 8x4 block of destination values that
373 // we have computed, fit in the destination matrix. Typically, all of
374 // it fits, but when the destination matrix shape is not a multiple
375 // of 8x4, there are some 8x8 blocks along the boundaries that do
376 // not fit entirely.
377 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
378 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
379 "sub r1, r1, r8\n"
380
381 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
382 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
383 "sub r2, r2, r4\n"
384 "mov r3, #8\n"
385 "mov r5, #4\n"
386 "cmp r1, #8\n"
387 // Compute r1 = how many rows of the 8x4 block fit
388 "it gt\n"
389 "movgt r1, r3\n"
390 "cmp r2, #4\n"
391 // Compute r2 = how many cols of the 8x4 block fit
392 "it gt\n"
393 "movgt r2, r5\n"
394
395 // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
396 "cmp r1, r3\n"
397 "it eq\n"
398 "cmpeq r2, r5\n"
399 // Yes, all of the 8x4 block fits, go to fast path.
400 "beq 30f\n"
401 // Not all of the 8x4 block fits.
402 // Set (r3 address, r4 stride) to write to dst_tmp_buf
403 "mov r3, %[dst_tmp_buf]\n"
404 "mov r4, #32\n"
405 "b 31f\n"
406 "30:\n"
407 // Yes, all of the 8x4 block fits.
408 // Set (r3 address, r4 stride) to write directly to destination matrix.
409 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
410 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
411 "mov r4, r5\n"
412 "31:\n"
413
414 // Write our float values to the destination described by
415 // (r3 address, r4 stride)
416 "vst1.32 {d6, d7, d8, d9}, [r3]\n"
417 "add r3, r3, r4\n"
418 RUY_MAKE_ZERO(q3)
419 RUY_MAKE_ZERO(q4)
420 "vst1.32 {d10, d11, d12, d13}, [r3]\n"
421 "add r3, r3, r4\n"
422 RUY_MAKE_ZERO(q5)
423 RUY_MAKE_ZERO(q6)
424 "vst1.32 {d14, d15, d16, d17}, [r3]\n"
425 "add r3, r3, r4\n"
426 RUY_MAKE_ZERO(q7)
427 RUY_MAKE_ZERO(q8)
428 "vst1.32 {d18, d19, d20, d21}, [r3]\n"
429 "add r3, r3, r4\n"
430 RUY_MAKE_ZERO(q9)
431 RUY_MAKE_ZERO(q10)
432
433 // If all of the 8x4 block fits, we just finished writing it to the
434 // destination, so we skip the next part.
435 "beq 41f\n"
436 // Not all of the 8x8 block fits in the destination matrix. We just
437 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
438 // it to copy into the destination matrix the part that fits.
439 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
440 "mov r3, %[dst_tmp_buf]\n"
441 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
442 "mov r6, #0\n"
443 "50:\n"
444 "mov r5, #0\n"
445 "51:\n"
446 "ldr r10, [r3, r5, lsl #2]\n"
447 "str r10, [r4, r5, lsl #2]\n"
448 "add r5, r5, #1\n"
449 "cmp r5, r1\n"
450 "blt 51b\n"
451 "add r6, r6, #1\n"
452 "add r3, r3, #32\n"
453 "add r4, r4, r8\n"
454 // r2 = how many cols of the 8x4 block fit
455 "cmp r6, r2\n"
456 "blt 50b\n"
457 "41:\n"
458 // Load dst_ptr, increment, and write back.
459 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
460 "add r4, r4, #32\n"
461 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
462 // At this point we have completely finished writing values to the
463 // destination matrix for the current block.
464
465 // Reload some params --- we had used r3, r5, r10 for a few other things
466 // since the last time we had loaded them.
467 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
468 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
469 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
470
471 // Move to the next block of the destination matrix, for the next iter
472 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
473 // been updated earlier.
474 // Have we reached the end row?
475 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
476 "cmp r8, r3\n"
477
478 "beq 20f\n" // yes, end row.
479 // Not end row. Move to the next row.
480 "add r8, r8, #8\n"
481 // Store new value of row
482 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
483
484 "b 21f\n"
485 "20:\n"
486 // Was already at end row.
487 // Move back to first row.
488 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
489 // Move to the next column.
490 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
491 "add r4, r4, #4\n"
492 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
493
494 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
495 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
496 // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
497 "add r1, r1, r8, lsl #2\n"
498 // Store dst_col_ptr
499 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
500 // Store dst_ptr
501 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
502 "21:\n"
503
504 // Main loop exit condition: have we hit the end column?
505 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
506 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
507 "cmp r8, r4\n"
508
509 // r1 is the number of levels of depth that we have already loaded
510 // LHS and RHS data for. Corresponding to the initial ld1 instructions
511 // above, this is currently 1.
512 "mov r1, #1\n"
513
514 "ble 1b\n"
515
516 // Restore stack pointer.
517 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
518
519 // clang-format on
520 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
521 : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf)
522 // Clobber list must specify q registers (and not their constituent
523 // d registers). There is a (currently unexplained) slowdown if
524 // d registers are listed in the clobbers list.
525 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
526 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
527 "q9", "q10", "q12", "q13");
528 }
529
530 #undef RUY_MAKE_ZERO
531 #undef RUY_STACK_OFFSET_SIZE
532 #undef RUY_STACK_OFFSET_DST_COL_PTR
533 #undef RUY_STACK_OFFSET_DST_PTR
534 #undef RUY_STACK_OFFSET_ROW
535 #undef RUY_STACK_OFFSET_COL
536 #undef RUY_STACK_OFFSET_LHS_COL_PTR
537 #undef RUY_STACK_OFFSET_RHS_COL_PTR
538
539 #undef RUY_OFFSET_LHS_BASE_PTR
540 #undef RUY_OFFSET_RHS_BASE_PTR
541 #undef RUY_OFFSET_DST_BASE_PTR
542 #undef RUY_OFFSET_BIAS
543 #undef RUY_OFFSET_START_ROW
544 #undef RUY_OFFSET_START_COL
545 #undef RUY_OFFSET_LAST_ROW
546 #undef RUY_OFFSET_LAST_COL
547 #undef RUY_OFFSET_DST_ROWS
548 #undef RUY_OFFSET_DST_COLS
549 #undef RUY_OFFSET_LHS_STRIDE
550 #undef RUY_OFFSET_RHS_STRIDE
551 #undef RUY_OFFSET_DST_STRIDE
552 #undef RUY_OFFSET_DEPTH
553 #undef RUY_OFFSET_CLAMP_MIN
554 #undef RUY_OFFSET_CLAMP_MAX
555 #undef RUY_OFFSET_FLAGS
556
557 #define RUY_OFFSET_BIAS 0
558 #define RUY_OFFSET_LHS_SUMS 4
559 #define RUY_OFFSET_RHS_SUMS 8
560 #define RUY_OFFSET_LHS_BASE_PTR 12
561 #define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
562 #define RUY_OFFSET_MULTIPLIER_EXPONENT 20
563 #define RUY_OFFSET_RHS_BASE_PTR 24
564 #define RUY_OFFSET_DST_BASE_PTR 28
565 #define RUY_OFFSET_LHS_ZERO_POINT 32
566 #define RUY_OFFSET_RHS_ZERO_POINT 36
567 #define RUY_OFFSET_DST_ZERO_POINT 40
568 #define RUY_OFFSET_PROD_ZP_DEPTH 44
569 #define RUY_OFFSET_START_ROW 48
570 #define RUY_OFFSET_START_COL 52
571 #define RUY_OFFSET_LAST_ROW 56
572 #define RUY_OFFSET_LAST_COL 60
573 #define RUY_OFFSET_DST_ROWS 64
574 #define RUY_OFFSET_DST_COLS 68
575 #define RUY_OFFSET_LHS_STRIDE 72
576 #define RUY_OFFSET_RHS_STRIDE 76
577 #define RUY_OFFSET_DST_STRIDE 80
578 #define RUY_OFFSET_DEPTH 84
579 #define RUY_OFFSET_CLAMP_MIN 88
580 #define RUY_OFFSET_CLAMP_MAX 92
581 #define RUY_OFFSET_FLAGS 96
582 #define RUY_OFFSET_DST_TYPE_ID 97
583
584 #define RUY_STACK_OFFSET_SIZE 96
585 #define RUY_STACK_OFFSET_DST_COL_PTR 0
586 #define RUY_STACK_OFFSET_DST_PTR 16
587 #define RUY_STACK_OFFSET_ROW 32
588 #define RUY_STACK_OFFSET_COL 48
589 #define RUY_STACK_OFFSET_LHS_COL_PTR 64
590 #define RUY_STACK_OFFSET_RHS_COL_PTR 80
591
592 template <typename Params>
CheckOffsetsInKernelParams8bit(const Params &)593 void CheckOffsetsInKernelParams8bit(const Params&) {
594 static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
595 "");
596 static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
597 "");
598 static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
599 "");
600 static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
601 "");
602 static_assert(offsetof(Params, multiplier_fixedpoint) ==
603 RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
604 "");
605 static_assert(
606 offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
607 "");
608 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
609 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
610 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
611 static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
612 static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
613 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
614 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
615 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
616 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
617 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
618 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
619 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
620 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
621 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
622 }
623
624 // Fast-int8 kernel, ported from ARM 64 version.
625 // Relevant target CPUs for this kernel include Krait 400 and A9,
626 // since these are 32-bit, out-of-order CPUs.
Kernel8bitNeon(const KernelParams8bit<4,2> & params)627 void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
628 profiler::ScopeLabel label("Kernel (kNeon)");
629
630 CheckOffsetsInKernelParams8bit(params);
631
632 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
633 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
634 const std::int8_t* lhs_ptr = lhs_col_ptr;
635 const std::int8_t* rhs_ptr = rhs_col_ptr;
636
637 // The asm kernel below has the following NEON register allocation:
638 //
639 // q6 - q13 are 128-bit (4x32b) accumulators.
640 // During accumulation, d0 -- d7 are used to load int8 data from LHS and
641 // d8 -- d11 from RHS:
642 // int8 RHS 16x2 block
643 // /-----------------------------|
644 // |d8.b[0-7] ..... d10.b[0-7]|
645 // | ... ... |
646 // |d9.b[0-7] ..... d11.b[0-7]|
647 // \-----------------------------/
648 // int8 LHS 4x16 block
649 // /------------------------\ /-----------------------------|
650 // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
651 // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
652 // (Reload d0, d1, d2, d3)
653 // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
654 // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
655 // \------------------------/ \-----------------------------/
656 // 128-bit accumulators 4x2 block
657 //
658 // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
659 // optimization for this kernel.
660 asm volatile(
661 #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
662
663 // clang-format off
664
665 // Load the first 64 bytes of LHS and RHS data.
666 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
667 // Clear accumulators.
668 RUY_MAKE_ZERO(q6)
669 RUY_MAKE_ZERO(q7)
670 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
671 RUY_MAKE_ZERO(q8)
672 RUY_MAKE_ZERO(q9)
673 RUY_MAKE_ZERO(q10)
674 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
675 RUY_MAKE_ZERO(q11)
676 "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
677
678 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
679
680 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
681 RUY_MAKE_ZERO(q12)
682 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
683
684 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
685 RUY_MAKE_ZERO(q13)
686 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
687
688 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
689 RUY_MAKE_ZERO(q14)
690 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
691
692 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
693 RUY_MAKE_ZERO(q15)
694 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
695
696 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
697 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
698
699 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
700 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
701
702
703 // r1 is the number of levels of depth that we have already loaded
704 // LHS and RHS data for. Corresponding to the initial ld1 instructions
705 // above, this is currently 16.
706 "mov r1, #16\n"
707
708 // Main loop of the whole GEMM, over rows and columns of the
709 // destination matrix.
710 "1:\n"
711
712 // r1 is how many levels of depth we have already loaded
713 // data for, r10 is the total depth.
714 "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
715 "cmp r1, r10\n"
716 "beq 79f\n"
717
718 "2:\n"
719
720 // Mult, mult-acc in to q14, q15, q2, q3
721 "vmull.s8 q14, d0, d8\n"
722 "vmull.s8 q2, d0, d10\n"
723
724 "vmull.s8 q15, d2, d8\n"
725 "vmull.s8 q3, d2, d10\n"
726
727 "vmlal.s8 q14, d1, d9\n"
728 "vmlal.s8 q2, d1, d11\n"
729 "vmlal.s8 q15, d3, d9\n"
730 "vmlal.s8 q3, d3, d11\n"
731 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
732
733 // Then pairwise accumulate in to q6, q7, q10, q11
734 "vpadal.s16 q6, q14\n"
735 "vpadal.s16 q7, q15\n"
736 "vpadal.s16 q10, q2\n"
737 "vpadal.s16 q11, q3\n"
738
739 // Mult, mult-acc in to q14, q15, q2, q3
740 "vmull.s8 q14, d0, d8\n"
741 "vmull.s8 q2, d0, d10\n"
742
743 "vmull.s8 q15, d2, d8\n"
744 "vmull.s8 q3, d2, d10\n"
745
746 "vmlal.s8 q14, d1, d9\n"
747 "vmlal.s8 q2, d1, d11\n"
748 "vmlal.s8 q15, d3, d9\n"
749 "vmlal.s8 q3, d3, d11\n"
750 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
751
752 // Then pairwise accumulate in to q8, q9, q12, q13
753 "vpadal.s16 q8, q14\n"
754 "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
755 "vpadal.s16 q9, q15\n"
756 "vpadal.s16 q12, q2\n"
757 "vpadal.s16 q13, q3\n"
758
759 // Prefetch the next 64 bytes of LHS and RHS data.
760 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
761 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
762
763 // Each iteration of this loop advances by 16 levels of depth.
764 "add r1, r1, #16\n"
765
766 // Loop termination condition
767 "cmp r1, r10\n"
768
769 "blt 2b\n"
770
771 "79:\n"
772
773 // Mult, mult-acc in to q14, q15, q2, q3
774 "vmull.s8 q14, d0, d8\n"
775 "vmull.s8 q2, d0, d10\n"
776
777 "vmull.s8 q15, d2, d8\n"
778 "vmull.s8 q3, d2, d10\n"
779
780 "vmlal.s8 q14, d1, d9\n"
781 "vmlal.s8 q2, d1, d11\n"
782 "vmlal.s8 q15, d3, d9\n"
783 "vmlal.s8 q3, d3, d11\n"
784 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
785
786 // Then pairwise accumulate in to q6, q7, q10, q11
787 "vpadal.s16 q6, q14\n"
788 "vpadal.s16 q7, q15\n"
789 "vpadal.s16 q10, q2\n"
790 "vpadal.s16 q11, q3\n"
791
792 // Mult, mult-acc in to q14, q15, q2, q3
793 "vmull.s8 q14, d0, d8\n"
794 "vmull.s8 q2, d0, d10\n"
795
796 "vmull.s8 q15, d2, d8\n"
797 "vmull.s8 q3, d2, d10\n"
798
799 "vmlal.s8 q14, d1, d9\n"
800 "vmlal.s8 q2, d1, d11\n"
801 "vmlal.s8 q15, d3, d9\n"
802 "vmlal.s8 q3, d3, d11\n"
803
804 // Then pairwise accumulate in to q8, q9, q12, q13
805 "vpadal.s16 q8, q14\n"
806 "vpadal.s16 q9, q15\n"
807 "vpadal.s16 q12, q2\n"
808 "vpadal.s16 q13, q3\n"
809
810
811 // All accumulation over depth done. q6 - q13 contain the 4x32b
812 // accumulators for the 4x2 final matrix.
813 // We now have to compute the final 8-bit values from these int32
814 // accumulators, and advance to the next 4x2 block. We intertwine
815 // these two aspects whenever possible for optimal pipelining, both
816 // at the data flow level (prefetch data for next block as early as
817 // possible) and instruction pipelining level (some of the next-block
818 // work can dual-issue with some of the final work on the current
819 // block).
820
821 // q6-q13 now contain 4 x 32b
822 "vpadd.i32 d0, d12, d13\n"
823 "vpadd.i32 d1, d14, d15\n"
824 "vpadd.i32 d2, d16, d17\n"
825 "vpadd.i32 d3, d18, d19\n"
826 "vpadd.i32 d4, d20, d21\n"
827 "vpadd.i32 d5, d22, d23\n"
828 "vpadd.i32 d6, d24, d25\n"
829 "vpadd.i32 d7, d26, d27\n"
830
831 // d0-d7 each contain 2 x 32b accumulators.
832 // Need to add pairwise to get 1 x 32b for each of the 4x2 entries
833 // of destination, (Four 'd' registers total)
834 "vpadd.i32 d28, d0, d1\n"
835 "vpadd.i32 d29, d2, d3\n"
836 "vpadd.i32 d30, d4, d5\n"
837 "vpadd.i32 d31, d6, d7\n"
838
839 //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries
840
841 // Logic to advance to the next block in preparation for the next
842 // iteration of the main loop. For now, we only want to compute
843 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
844 // not yet ready to update the values of row and col, as we still need
845 // the current values for the rest of the work on the current block.
846
847 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
848 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
849 "cmp r1, r3\n" // Have we finished the last row?
850
851 "bge 4f\n" // If finished last row, go to 4
852 // Not finished last row: then advance to next row.
853 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
854 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
855 "add r4, r4, r1, lsl #2\n"
856 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
857 "b 5f\n"
858 "4:\n" // Finished last row...
859 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
860 // Go back to first row
861 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
862
863 // Now we need to advance to the next column. If we already
864 // finished the last column, then in principle we are done, however
865 // we can't just return here, as we need to allow the end work of the
866 // current block to complete. The good news is that at this point it
867 // doesn't matter what data we load for the next column, since
868 // we will exit from the main loop below before actually storing
869 // anything computed from that data.
870
871 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
872 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
873 "cmp r8, r4\n" // Have we finished the last column?
874 "bge 5f\n" // If yes, just carry on without updating the column pointer.
875 // Not finished last column: then advance to next column.
876 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
877 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
878 "add r10, r10, r1, lsl #1\n"
879 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
880 "5:\n"
881
882 // Set the LHS and RHS data pointers to the start of the columns just
883 // computed.
884 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
885 "mov %[lhs_ptr], r4\n"
886 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
887 "mov %[rhs_ptr], r5\n"
888
889 // Now we load: bias data, LHS sums data, RHS sums data.
890
891 // First, load the base pointers from the params.
892 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
893 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
894
895 // Let r8 be stack offset of the row or column variable, whichever
896 // is the channel index.
897 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
898 "ite eq\n"
899 "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
900 "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
901 // Let r8 be the channel index.
902 "ldr r8, [sp, r8]\n"
903 // Compute the bias pointer, by conditionally using the channel index
904 // (r8) as offset into bias buffer (r1).
905 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
906 "it ne\n"
907 "addne r1, r1, r8, lsl #2\n"
908
909 // Load 2 bias values. When the channel dimension is rows, we will load
910 // another 2 bias values just before performing the bias addition below,
911 // as this kernel has a 4x2 rectangular shape.
912 "vld1.32 {d24}, [r1]!\n"
913
914           // Now that we know what LHS and RHS data the next iteration of the
915           // main loop will need to load, we start loading the first 32 bytes of
916           // each of LHS and RHS, into d0 -- d3 and d8 -- d11, as we don't need
917           // those registers anymore in the rest of the work on the current block.
918 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
919 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
920 "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
921 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
922
923 // Add to the bias values the product
924 // (depth * lhs_zero_point * rhs_zero_point),
925 // See the term NZ1Z2 in equation (7) in
926 // https://arxiv.org/pdf/1712.05877.pdf
927 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
928 "vdup.32 q9, r3\n"
929 "vadd.i32 d24, d24, d18\n"
930
931 // Perform the bias-addition (per the above, we have just folded into
932 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
933 // Jump based on channel dimension.
934 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
935 "bne 6f\n"
936 // Case where channels are rows.
937 // Load the remaining 2 bias values, since we're on the width-4 side
938 // of this 4x2 kernel.
939 "vld1.32 {d25}, [r1]\n"
940 "vadd.i32 d25, d25, d19\n"
941 "vadd.i32 q14, q14, q12\n"
942 "vadd.i32 q15, q15, q12\n"
943 "b 7f\n"
944
945 "6:\n"
946 // Case where channels are columns.
947 "vdup.32 q10, d24[0]\n"
948 "vdup.32 q11, d24[1]\n"
949 "vadd.i32 q14, q14, q10\n"
950 "vadd.i32 q15, q15, q11\n"
951 "7:\n"
952
953 // LHS/RHS zero points
954 // Has RHS sums
955 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
956 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
957 "beq 401f\n"
958 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
959 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
960 // Offset by current col * number of bytes per value
961 "add r3, r3, r4, lsl #2\n"
962 "vld1.32 { d12 }, [r3]\n"
963 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
964 "vdup.32 q10, r5\n" // create lhs_zero_point_vec
965 // Subtract rhs_sums * lhs_zero_point, per
966 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
967 "vmls.i32 q14, q10, d12[0]\n"
968 "vmls.i32 q15, q10, d12[1]\n"
969 "401:\n"
970
971 // Has LHS sums
972 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
973 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
974 "beq 402f\n"
975 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
976 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
977 // Offset by current row * number of bytes per value
978 "add r2, r2, r4, lsl #2\n"
979 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
980
981 // Load 4 lhs_sums values.
982 "vld1.32 {d22, d23}, [r2]\n"
983 "vdup.32 d13, r5\n" // rhs_zero_point
984
985 // Compute lhs_sums * rhs_zero_point.
986 "vmul.i32 q11, q11, d13[1]\n"
987 // Subtract lhs_sums * rhs_zero_point, per
988 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
989 "vsub.s32 q14, q14, q11\n"
990 "vsub.s32 q15, q15, q11\n"
991
992 // If the destination is int32, it means the user asks for the raw
993 // accumulators, no need for us to downquantize the value.
994 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
995 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
996 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
997
998 "402:\n"
999
1000 // At this point we have computed the final int32 values. Now we
1001 // start down-quantizing them to obtain the final 8bit values from them.
1002
1003 // As part of this down-quantization, our int32 values will be
1004 // multiplied by a multiplier that has a fixed-point component and an
1005 // exponent component.
1006
1007 // Compute the data pointers for the multiplier data
1008 // r1 = exponent part
1009 // r2 = fixedpoint part
1010 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1011 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1012 // r6 has flags, r8 has channel index
1013 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1014 "it ne\n"
1015 "addne r1, r1, r8, lsl #2\n"
1016 "it ne\n"
1017 "addne r2, r2, r8, lsl #2\n"
1018
1019 // Load the first 2 values of multiplier exponent and fixedpoint data
1020 // Since this kernel is rectangular 4x2, we will only conditionally load
1021 // 2 more values below.
1022 "vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent
1023 "vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint
1024
1025 "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
1026 "vmvn.i32 q8, #0\n"
1027 "bne 8f\n"
1028 // Case where channels are rows.
1029          // Load the remaining 2 multiplier values, since we're on the
1030          // width-4 side of this 4x2 kernel.
1031 "vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent
1032 "vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint
1033 "vmin.s32 q11, q10, q8\n"
1034 "vsub.s32 q10, q10, q11\n"
1035
1036 // Apply the positive exponent part of the multiplier.
1037 "vshl.s32 q14, q14, q10\n"
1038 "vshl.s32 q15, q15, q10\n"
1039
1040 // Apply the fixed-point part of the multiplier.
1041 "vqdmulh.s32 q14, q14, q6\n"
1042 "vqdmulh.s32 q15, q15, q6\n"
1043
1044 // Apply the negative exponent part of the multiplier.
1045 "vrshl.s32 q14, q14, q11\n"
1046 "vrshl.s32 q15, q15, q11\n"
1047 "b 9f\n"
1048
1049 "8:\n"
1050 // Case where channels are columns.
1051 "vmin.s32 d22, d20, d16\n"
1052 "vsub.s32 d20, d20, d22\n"
1053
1054 // Apply the positive exponent part of the multiplier.
1055 "vdup.32 q12, d20[0]\n"
1056 "vdup.32 q13, d20[1]\n"
1057 "vshl.s32 q14, q14, q12\n"
1058 "vshl.s32 q15, q15, q13\n"
1059
1060 // Apply the fixed-point part of the multiplier.
1061 "vqdmulh.s32 q14, q14, d12[0]\n"
1062 "vqdmulh.s32 q15, q15, d12[1]\n"
1063
1064 // Apply the negative exponent part of the multiplier.
1065 "vdup.32 q12, d22[0]\n"
1066 "vdup.32 q13, d22[1]\n"
1067 "vrshl.s32 q14, q14, q12\n"
1068 "vrshl.s32 q15, q15, q13\n"
1069
1070 "9:\n"
1071
1072 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1073 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1074 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1075 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1076 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1077
1078 // Store uint8 values:
1079 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1080
1081 // Cast-and-saturate from int32 to int16
1082 // After this, all values for output are in q14.
1083 "vqmovn.s32 d28, q14\n"
1084 "vqmovn.s32 d29, q15\n"
1085
1086          // At this point, d12 -- d27, d30, d31 aren't used anymore for the
1087 // current block, so we can start clearing these accumulators for the
1088 // next block (next iteration of the main loop).
1089 RUY_MAKE_ZERO(q6)
1090 RUY_MAKE_ZERO(q7)
1091 RUY_MAKE_ZERO(q8)
1092 RUY_MAKE_ZERO(q9)
1093 RUY_MAKE_ZERO(q10)
1094 RUY_MAKE_ZERO(q11)
1095 RUY_MAKE_ZERO(q12)
1096 RUY_MAKE_ZERO(q13)
1097 RUY_MAKE_ZERO(q15)
1098
1099 // Load the destination zero point into each of the 8 16-bit slots
1100 // in a q register.
1101 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1102 "vdup.16 q13, r4\n" // dst_zero_point
1103
1104 // Add the destination zero point
1105 "vadd.i16 q14, q14, q13\n"
1106
1107 // Cast-and-saturate from int16 to uint8
1108 // Now all 8 1-byte values are in d30.
1109 "vqmovun.s16 d30, q14\n"
1110
1111 // Load the clamp_min, clamp_max bounds
1112 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1113 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1114 "vdup.8 d28, r2\n" // clamp_min
1115 "vdup.8 d29, r3\n" // clamp_max
1116
1117 // Apply the clamp_min bound
1118 "vmax.u8 d30, d30, d28\n"
1119 // Apply the clamp_max bound
1120 "vmin.u8 d30, d30, d29\n"
1121
1122 // Compute how much of the 4x2 block of destination 8bit values that
1123 // we have computed, fit in the destination matrix. Typically, all of
1124 // it fits, but when the destination matrix shape is not a multiple
1125 // of 4x2, there are some 4x2 blocks along the boundaries that do
1126 // not fit entirely.
1127
1128 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1129 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1130 "sub r1, r1, r8\n"
1131
1132 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1133 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1134 "sub r2, r2, r4\n"
1135 "mov r3, #4\n"
1136 "mov r5, #2\n"
1137 "cmp r1, #4\n"
1138 // Compute r1 = how many rows of the 4x2 block fit
1139 "it gt\n"
1140 "movgt r1, r3\n"
1141
1142 "cmp r2, #2\n"
1143 // Compute r2 = how many cols of the 4x2 block fit
1144 "it gt\n"
1145 "movgt r2, r5\n"
1146
1147 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1148 "cmp r1, r3\n"
1149 "it eq\n"
1150 "cmpeq r2, r5\n"
1151 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1152 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1153 // Yes, all of the 4x2 block fits, go to fast path.
1154 "beq 30f\n"
1155 // Not all of the 4x2 block fits.
1156 // Store to dst_tmp_buf
1157 // Set r3 address to write to dst_tmp_buf.
1158 "mov r3, %[dst_tmp_buf]\n"
1159 "vst1.8 {d30}, [r3]\n"
1160
1161 // Slow loop copying from dst_tmp_buf to dst.
1162 "mov r6, #0\n"
1163 "50:\n"
1164 "mov r8, #0\n"
1165 "51:\n"
1166 "ldrb r10, [r3, r8]\n"
1167 "strb r10, [r4, r8]\n"
1168 "add r8, r8, #1\n"
1169 "cmp r8, r1\n"
1170 "blt 51b\n"
1171 "add r6, r6, #1\n"
1172 "add r3, r3, #4\n"
1173 "add r4, r4, r5\n"
1174 "cmp r6, r2\n"
1175 "blt 50b\n"
1176 "b 31f\n"
1177 "30:\n"
1178 // Yes, all of the 4x2 block fits.
1179 // r3 address, r5 stride
1180 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1181 "mov r4, r3\n"
1182 "mov r6, #1\n"
1183
1184 "vst1.32 {d30[0]}, [r3]\n"
1185 "add r4, r4, r5\n"
1186 "mov r3, r4\n"
1187 "vst1.32 {d30[1]}, [r3]\n"
1188
1189 "31:\n"
1190
1191 // Load dst_ptr, increment, and write back.
1192 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1193 "add r4, r4, #4\n"
1194 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1195
1196 RUY_MAKE_ZERO(q13)
1197 RUY_MAKE_ZERO(q14)
1198 RUY_MAKE_ZERO(q15)
1199
1200 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1201
1202 // Store int8 values:
1203 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
1204
1205 // Cast-and-saturate from int32 to int16
1206 // After this, all values for output are in q14.
1207 "vqmovn.s32 d28, q14\n"
1208 "vqmovn.s32 d29, q15\n"
1209
1210          // At this point, d12 -- d27, d30, d31 aren't used anymore for the
1211 // current block, so we can start clearing these accumulators for the
1212 // next block (next iteration of the main loop).
1213 RUY_MAKE_ZERO(q6)
1214 RUY_MAKE_ZERO(q7)
1215 RUY_MAKE_ZERO(q8)
1216 RUY_MAKE_ZERO(q9)
1217 RUY_MAKE_ZERO(q10)
1218 RUY_MAKE_ZERO(q11)
1219 RUY_MAKE_ZERO(q12)
1220 RUY_MAKE_ZERO(q13)
1221 RUY_MAKE_ZERO(q15)
1222
1223 // Load the destination zero point into each of the 8 16-bit slots
1224 // in a q register.
1225 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1226 "vdup.16 q13, r4\n" // dst_zero_point
1227
1228 // Add the destination zero point
1229 "vadd.i16 q14, q14, q13\n"
1230
1231 // Cast-and-saturate from int16 to int8
1232 // Now all 8 1-byte values are in d30.
1233 "vqmovn.s16 d30, q14\n"
1234
1235 // Load the clamp_min, clamp_max bounds
1236 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1237 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1238 "vdup.8 d28, r2\n" // clamp_min
1239 "vdup.8 d29, r3\n" // clamp_max
1240
1241 // Apply the clamp_min bound
1242 "vmax.s8 d30, d30, d28\n"
1243 // Apply the clamp_max bound
1244 "vmin.s8 d30, d30, d29\n"
1245
1246 // Compute how much of the 4x2 block of destination 8bit values that
1247 // we have computed, fit in the destination matrix. Typically, all of
1248 // it fits, but when the destination matrix shape is not a multiple
1249 // of 4x2, there are some 4x2 blocks along the boundaries that do
1250 // not fit entirely.
1251
1252 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1253 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1254 "sub r1, r1, r8\n"
1255
1256 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1257 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1258 "sub r2, r2, r4\n"
1259 "mov r3, #4\n"
1260 "mov r5, #2\n"
1261 "cmp r1, #4\n"
1262 // Compute r1 = how many rows of the 4x2 block fit
1263 "it gt\n"
1264 "movgt r1, r3\n"
1265
1266 "cmp r2, #2\n"
1267 // Compute r2 = how many cols of the 4x2 block fit
1268 "it gt\n"
1269 "movgt r2, r5\n"
1270
1271 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1272 "cmp r1, r3\n"
1273 "it eq\n"
1274 "cmpeq r2, r5\n"
1275 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1276 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1277 // Yes, all of the 4x2 block fits, go to fast path.
1278 "beq 30f\n"
1279 // Not all of the 4x2 block fits.
1280 // Store to dst_tmp_buf
1281 // Set r3 address to write to dst_tmp_buf.
1282 "mov r3, %[dst_tmp_buf]\n"
1283 "vst1.8 {d30}, [r3]\n"
1284
1285 // Slow loop copying from dst_tmp_buf to dst.
1286 "mov r6, #0\n"
1287 "50:\n"
1288 "mov r8, #0\n"
1289 "51:\n"
1290 "ldrb r10, [r3, r8]\n"
1291 "strb r10, [r4, r8]\n"
1292 "add r8, r8, #1\n"
1293 "cmp r8, r1\n"
1294 "blt 51b\n"
1295 "add r6, r6, #1\n"
1296 "add r3, r3, #4\n"
1297 "add r4, r4, r5\n"
1298 "cmp r6, r2\n"
1299 "blt 50b\n"
1300 "b 31f\n"
1301 "30:\n"
1302 // Yes, all of the 4x2 block fits.
1303 // r3 address, r5 stride
1304 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1305 "mov r4, r3\n"
1306 "mov r6, #1\n"
1307
1308 "vst1.32 {d30[0]}, [r3]\n"
1309 "add r4, r4, r5\n"
1310 "mov r3, r4\n"
1311 "vst1.32 {d30[1]}, [r3]\n"
1312
1313 "31:\n"
1314
1315 // Load dst_ptr, increment, and write back.
1316 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1317 "add r4, r4, #4\n"
1318 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1319
1320 RUY_MAKE_ZERO(q13)
1321 RUY_MAKE_ZERO(q14)
1322 RUY_MAKE_ZERO(q15)
1323
1324 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1325
1326 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
1327
1328 // Load the destination zero point into each of the 4 32-bit slots
1329 // in a q register.
1330 "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1331 "vdup.32 q13, r4\n" // dst_zero_point
1332 // Add the destination zero point
1333 "vadd.s32 q14, q14, q13\n"
1334 "vadd.s32 q15, q15, q13\n"
1335
1336 // Cast-and-saturate from int32 to int16
1337 // After this, all values for output are in q14.
1338 "vqmovn.s32 d28, q14\n"
1339 "vqmovn.s32 d29, q15\n"
1340
1341          // At this point, q6 -- q11 and q15 aren't used anymore for the current block,
1342 // so we can start clearing these accumulators for the next block
1343 // (next iteration of the main loop).
1344 RUY_MAKE_ZERO(q6)
1345 RUY_MAKE_ZERO(q7)
1346 RUY_MAKE_ZERO(q8)
1347 RUY_MAKE_ZERO(q9)
1348 RUY_MAKE_ZERO(q10)
1349 RUY_MAKE_ZERO(q11)
1350 RUY_MAKE_ZERO(q15)
1351
1352 // Load the clamp_min, clamp_max bounds
1353 "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1354 "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1355 "vdup.16 q12, r2\n" // clamp_min
1356 "vdup.16 q13, r3\n" // clamp_max
1357
1358 // Apply the clamp_min bound
1359 "vmax.s16 q14, q14, q12\n"
1360 // Apply the clamp_max bound
1361 "vmin.s16 q14, q14, q13\n"
1362
1363 RUY_MAKE_ZERO(q12)
1364 RUY_MAKE_ZERO(q13)
1365
1366 // Compute how much of the 4x2 block of destination 16-bit values that
1367 // we have computed, fit in the destination matrix. Typically, all of
1368 // it fits, but when the destination matrix shape is not a multiple
1369 // of 4x2, there are some 4x2 blocks along the boundaries that do
1370 // not fit entirely.
1371
1372 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1373 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1374 "sub r1, r1, r8\n"
1375
1376 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1377 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1378 "sub r2, r2, r4\n"
1379 "mov r3, #4\n"
1380 "mov r5, #2\n"
1381 "cmp r1, #4\n"
1382 // Compute r1 = how many rows of the 4x2 block fit
1383 "it gt\n"
1384 "movgt r1, r3\n"
1385
1386 "cmp r2, #2\n"
1387 // Compute r2 = how many cols of the 4x2 block fit
1388 "it gt\n"
1389 "movgt r2, r5\n"
1390
1391 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1392 "cmp r1, r3\n"
1393 "it eq\n"
1394 "cmpeq r2, r5\n"
1395 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1396 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1397 // Yes, all of the 4x2 block fits, go to fast path.
1398 "beq 30f\n"
1399 // Not all of the 4x2 block fits.
1400 // Store to dst_tmp_buf
1401 // Set r3 address to write to dst_tmp_buf.
1402 "mov r3, %[dst_tmp_buf]\n"
1403 "vst1.16 {q14}, [r3]\n"
1404
1405 // Slow loop copying from dst_tmp_buf to dst.
1406 "mov r6, #0\n"
1407 "50:\n"
1408 "mov r8, #0\n"
1409 "51:\n"
1410 // Shift of offset register for half-word loads not allowed in A32,
1411 // so we shift, load/store, then shift back r8.
1412 "lsl r8, r8, #1\n"
1413 "ldrh r10, [r3, r8]\n"
1414 "strh r10, [r4, r8]\n"
1415 "lsr r8, r8, #1\n"
1416 "add r8, r8, #1\n"
1417 "cmp r8, r1\n"
1418 "blt 51b\n"
1419 "add r6, r6, #1\n"
1420 "add r3, r3, #8\n"
1421 "add r4, r4, r5\n"
1422 "cmp r6, r2\n"
1423 "blt 50b\n"
1424 "b 31f\n"
1425 "30:\n"
1426 // Yes, all of the 4x2 block fits.
1427 // r3 address, r5 stride
1428 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1429 "mov r4, r3\n"
1430 "mov r6, #2\n"
1431
1432 "vst1.16 {d28[0]}, [r3], r6\n"
1433 "add r4, r4, r5\n"
1434 "vst1.16 {d28[1]}, [r3], r6\n"
1435 "vst1.16 {d28[2]}, [r3], r6\n"
1436 "vst1.16 {d28[3]}, [r3], r6\n"
1437 "mov r3, r4\n"
1438 "vst1.16 {d29[0]}, [r3], r6\n"
1439 "vst1.16 {d29[1]}, [r3], r6\n"
1440 "vst1.16 {d29[2]}, [r3], r6\n"
1441 "vst1.16 {d29[3]}, [r3], r6\n"
1442 "31:\n"
1443
1444 // Load dst_ptr, increment, and write back.
1445 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1446 "add r4, r4, #8\n"
1447 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1448
1449 RUY_MAKE_ZERO(q14)
1450
1451 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1452
1453 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
1454
1455 // Since the store type is the same as the accum type, no need for
1456 // downcast. There's also no need for clamp by min/max.
1457
1458          // At this point, q6 -- q13 aren't used anymore for the current block,
1459 // so we can start clearing these accumulators for the next block
1460 // (next iteration of the main loop).
1461 // Clear accumulators.
1462 RUY_MAKE_ZERO(q6)
1463 RUY_MAKE_ZERO(q7)
1464 RUY_MAKE_ZERO(q8)
1465 RUY_MAKE_ZERO(q9)
1466 RUY_MAKE_ZERO(q10)
1467 RUY_MAKE_ZERO(q11)
1468 RUY_MAKE_ZERO(q12)
1469 RUY_MAKE_ZERO(q13)
1470
1471 // Compute how much of the 4x2 block of destination 32 bit values that
1472 // we have computed, fit in the destination matrix. Typically, all of
1473 // it fits, but when the destination matrix shape is not a multiple
1474          // of 4x2, there are some 4x2 blocks along the boundaries that do
1475 // not fit entirely.
1476
1477 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1478 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1479 "sub r1, r1, r8\n"
1480
1481 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1482 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1483 "sub r2, r2, r4\n"
1484 "mov r3, #4\n"
1485 "mov r5, #2\n"
1486 "cmp r1, #4\n"
1487 // Compute r1 = how many rows of the 4x2 block fit
1488 "it gt\n"
1489 "movgt r1, r3\n"
1490
1491 "cmp r2, #2\n"
1492 // Compute r2 = how many cols of the 4x2 block fit
1493 "it gt\n"
1494 "movgt r2, r5\n"
1495
1496 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1497 "cmp r1, r3\n"
1498 "it eq\n"
1499 "cmpeq r2, r5\n"
1500 // Yes, all of the 4x2 block fits, go to fast path.
1501 "beq 30f\n"
1502 // Not all of the 4x2 block fits.
1503 // Set (r3 address, r4 stride) to write to dst_tmp_buf
1504 "mov r3, %[dst_tmp_buf]\n"
1505 "mov r4, #16\n"
1506 "b 31f\n"
1507
1508 "30:\n"
1509 // Yes, all of the 4x2 block fits.
1510 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1511 // r3 address, r4 stride
1512 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1513 "mov r4, r5\n"
1514
1515 "31:\n"
1516
1517 "vst1.32 {d28, d29}, [r3]\n"
1518 "add r3, r3, r4\n"
1519 "vst1.32 {d30, d31}, [r3]\n"
1520
1521 // If all of the 4x2 block fits, we just finished writing it to the
1522 // destination, so we skip the next part.
1523 "beq 41f\n"
1524 // Not all of the 4x2 block fits in the destination matrix. We just
1525 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
1526 // it to copy into the destination matrix the part that fits.
1527 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1528 "mov r3, %[dst_tmp_buf]\n"
1529 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1530 "mov r6, #0\n"
1531 "50:\n"
1532 "mov r5, #0\n"
1533 "51:\n"
1534 "ldr r10, [r3, r5, lsl #2]\n"
1535 "str r10, [r4, r5, lsl #2]\n"
1536 "add r5, r5, #1\n"
1537 "cmp r5, r1\n"
1538 "blt 51b\n"
1539 "add r6, r6, #1\n"
1540 "add r3, r3, #16\n"
1541 "add r4, r4, r8\n"
1542          // r2 = how many cols of the 4x2 block fit
1543 "cmp r6, r2\n"
1544 "blt 50b\n"
1545
1546 "41:\n"
1547 // Load dst_ptr, increment, and write back.
1548 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1549 "add r4, r4, #16\n"
1550 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1551
1552 RUY_MAKE_ZERO(q10)
1553 RUY_MAKE_ZERO(q11)
1554
1555 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1556
1557 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1558
1559          // Reload some params --- we had used r3, r5, r6 for a few other things
1560 // since the last time we had loaded them.
1561 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1562 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1563 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1564
1565 // Move to the next block of the destination matrix, for the next iter
1566 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
1567 // been updated earlier.
1568 // Have we reached the end row?
1569 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1570 "cmp r8, r3\n"
1571
1572 "beq 20f\n" // yes, end row.
1573 // Not end row. Move to the next row.
1574 "add r8, r8, #4\n"
1575 // Store new value of row
1576 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1577
1578 "b 21f\n"
1579 "20:\n"
1580 // Was already at end row.
1581 // Move back to first row.
1582 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1583 // Move to the next column.
1584 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1585 "add r4, r4, #2\n"
1586 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1587
1588 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1589 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1590 // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
1591 "add r1, r1, r8, lsl #1\n"
1592 // Store dst_col_ptr
1593 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1594 // Store dst_ptr
1595 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1596 "21:\n"
1597
1598 // Main loop exit condition: have we hit the end column?
1599 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1600 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1601 "cmp r8, r4\n"
1602
1603          // r1 is the number of levels of depth that we have already loaded
1604 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1605 // above, this is currently 16.
1606 "mov r1, #16\n"
1607
1608 "ble 1b\n"
1609
1610 // Restore stack pointer.
1611 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1612
1613 // clang-format on
1614
1615 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
1616          : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
1617 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
1618 // Clobber list must specify q registers (and not their constituent
1619 // d registers). There is a (currently unexplained) slowdown if
1620 // d registers are listed in the clobbers list.
1621 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
1622 "q9", "q10", "q12", "q13", "q14", "q15");
1623 }
1624
1625 // Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
1626 // is still packed as if it has two columns
Kernel8bitNeon1Col(const KernelParams8bit<4,2> & params)1627 void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
1628 profiler::ScopeLabel label("Kernel (kNeon)");
1629
1630 CheckOffsetsInKernelParams8bit(params);
1631
1632 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1633 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
1634 const std::int8_t* lhs_ptr = lhs_col_ptr;
1635 const std::int8_t* rhs_ptr = rhs_col_ptr;
1636
1637 RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
1638
1639 // The asm kernel below has the following NEON register allocation:
1640 //
1641 // q6 - q13 are 128-bit (4x32b) accumulators.
1642 // During accumulation, d0 -- d7 are used to load int8 data from LHS and
1643 // d8 -- d11 from RHS:
1644 // int8 RHS 16x1 block
1645 // /------------|
1646 // | d8.b[0] |
1647 // | ... |
1648 // | d8.b[7] |
1649 // | d9.b[0] |
1650 // | ... |
1651 // | d9.b[7] |
1652 // \------------/
1653 // int8 LHS 4x16 block
1654 // /-----------------------------------------\ /------------|
1655 // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
1656 // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
1657 // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
1658 // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
1659 // \-----------------------------------------/ \------------/
1660 // 128-bit accumulators 4x1 block
1661 //
1662 // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
1663 // optimization for this kernel.
1664 asm volatile(
1665 #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
1666
1667 // clang-format off
1668
1669          // Load the first 64 bytes of LHS data and 16 bytes of RHS data.
1670 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1671 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1672 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1673 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1674 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1675 // Skip the other column and advance the pointer.
1676 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1677
1678 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1679
1680 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1681 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1682
1683 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1684 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1685
1686 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1687 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1688
1689 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
1690 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1691
1692 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1693 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1694
1695 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
1696 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1697
1698 // Clear accumulators.
1699 RUY_MAKE_ZERO(q6)
1700 RUY_MAKE_ZERO(q7)
1701 RUY_MAKE_ZERO(q8)
1702 RUY_MAKE_ZERO(q9)
1703 RUY_MAKE_ZERO(q10)
1704 RUY_MAKE_ZERO(q11)
1705 RUY_MAKE_ZERO(q12)
1706 RUY_MAKE_ZERO(q13)
1707 RUY_MAKE_ZERO(q14)
1708 RUY_MAKE_ZERO(q15)
1709
1710 // r1 is the number of levels of depth that we have already loaded
1711 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1712 // above, this is currently 16.
1713 "mov r1, #16\n"
1714
1715 // Main loop of the whole GEMM, over rows and columns of the
1716 // destination matrix.
1717 "1:\n"
1718
1719 // r1 is how many levels of depth we have already loaded
1720 // data for, r10 is the total depth.
1721 "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
1722 "cmp r1, r10\n"
1723 "beq 79f\n"
1724
1725 "2:\n"
1726
1727 // Mult, mult-acc in to q14, q15
1728 "vmull.s8 q14, d0, d8\n"
1729 "vmull.s8 q15, d2, d8\n"
1730 "vmlal.s8 q14, d1, d9\n"
1731 "vmlal.s8 q15, d3, d9\n"
1732
1733 // Then pairwise accumulate in to q6, q7
1734 "vpadal.s16 q6, q14\n"
1735 "vpadal.s16 q7, q15\n"
1736
1737 // Mult, mult-acc in to q14, q15
1738 "vmull.s8 q14, d4, d8\n"
1739 "vmull.s8 q15, d6, d8\n"
1740 "vmlal.s8 q14, d5, d9\n"
1741 "vmlal.s8 q15, d7, d9\n"
1742
1743 // Then pairwise accumulate in to q8, q9
1744 "vpadal.s16 q8, q14\n"
1745 "vpadal.s16 q9, q15\n"
1746
1747
1748          // Load the next 64 bytes of LHS data and 16 bytes of RHS data.
1749 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1750 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1751 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1752 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1753 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1754 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1755 // Skip the other column and advance the pointer.
1756 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1757 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1758
1759 // Each iteration of this loop advances by 16 levels of depth.
1760 "add r1, r1, #16\n"
1761
1762 // Loop termination condition
1763 "cmp r1, r10\n"
1764
1765 "blt 2b\n"
1766
1767 "79:\n"
1768
1769 // Mult, mult-acc in to q14, q15
1770 "vmull.s8 q14, d0, d8\n"
1771 "vmull.s8 q15, d2, d8\n"
1772 "vmlal.s8 q14, d1, d9\n"
1773 "vmlal.s8 q15, d3, d9\n"
1774
1775 // Then pairwise accumulate in to q6, q7
1776 "vpadal.s16 q6, q14\n"
1777 "vpadal.s16 q7, q15\n"
1778
1779 // Mult, mult-acc in to q14, q15
1780 "vmull.s8 q14, d4, d8\n"
1781 "vmull.s8 q15, d6, d8\n"
1782 "vmlal.s8 q14, d5, d9\n"
1783 "vmlal.s8 q15, d7, d9\n"
1784
1785 // Then pairwise accumulate in to q8, q9
1786 "vpadal.s16 q8, q14\n"
1787 "vpadal.s16 q9, q15\n"
1788
1789 // All accumulation over depth done. q6 - q9 contain the 4x32b
1790 // accumulators for the 4x1 final matrix.
1791 // We now have to compute the final 8-bit values from these int32
1792 // accumulators, and advance to the next 4x2 block. We intertwine
1793 // these two aspects whenever possible for optimal pipelining, both
1794 // at the data flow level (prefetch data for next block as early as
1795 // possible) and instruction pipelining level (some of the next-block
1796 // work can dual-issue with some of the final work on the current
1797 // block).
1798
1799 // q6-q9 now contain 4 x 32b
1800 "vpadd.i32 d0, d12, d13\n"
1801 "vpadd.i32 d1, d14, d15\n"
1802 "vpadd.i32 d2, d16, d17\n"
1803 "vpadd.i32 d3, d18, d19\n"
1804
1805         // d0-d3 each contain 2 x 32b accumulators.
1806 // Need to add pairwise to get 1 x 32b for each of the 4x1 entries
1807 // of destination, (Four 'd' registers total)
1808 "vpadd.i32 d28, d0, d1\n"
1809 "vpadd.i32 d29, d2, d3\n"
1810
1811 // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
1812
1813 // Logic to advance to the next block in preparation for the next
1814 // iteration of the main loop. For now, we only want to compute
1815 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
1816 // not yet ready to update the values of row and col, as we still need
1817 // the current values for the rest of the work on the current block.
1818
1819 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1820 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1821 "cmp r1, r3\n" // Have we finished the last row?
1822
1823 "bge 4f\n" // If finished last row, go to 4
1824 // Not finished last row: then advance to next row.
1825 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
1826 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1827 "add r4, r4, r1, lsl #2\n"
1828 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1829 "b 5f\n"
1830 "4:\n" // Finished last row...
1831 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1832 // Go back to first row
1833 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1834
1835 // Now we need to advance to the next column. If we already
1836 // finished the last column, then in principle we are done, however
1837 // we can't just return here, as we need to allow the end work of the
1838 // current block to complete. The good news is that at this point it
1839 // doesn't matter what data we load for the next column, since
1840 // we will exit from the main loop below before actually storing
1841 // anything computed from that data.
1842
1843 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1844 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1845 "cmp r8, r4\n" // Have we finished the last column?
1846 "bge 5f\n" // If yes, just carry on without updating the column pointer.
1847 // Not finished last column: then advance to next column.
1848 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
1849 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1850 "add r10, r10, r1, lsl #1\n"
1851 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1852 "5:\n"
1853
1854 // Set the LHS and RHS data pointers to the start of the columns just
1855 // computed.
1856 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1857 "mov %[lhs_ptr], r4\n"
1858 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1859 "mov %[rhs_ptr], r5\n"
1860
1861 // Now we load: bias data, LHS sums data, RHS sums data.
1862
1863 // First, load the base pointers from the params.
1864 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1865 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
1866
1867 // Offset these base pointers as needed given the current row, col.
1868 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1869
1870 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
1871 "it ne\n"
1872 "addne r1, r1, r8, lsl #2\n"
1873
1874 // Load 4 bias values.
1875 "vld1.32 {d24, d25}, [r1]\n"
1876
1877 // Now that we know what LHS and RHS data the next iteration of the
1878 // main loop will need to load, we start loading the first 32 bytes of
1879         // each of LHS and RHS, into q0 -- q4, as we don't need q0 -- q4 anymore
1880 // in the rest of the work on the current block.
1881 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1882 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1883 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1884 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1885 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1886 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1887 // Skip the other column and advance the pointer.
1888 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1889 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1890
1891 // Add to the bias values the product
1892 // (depth * lhs_zero_point * rhs_zero_point),
1893 // See the term NZ1Z2 in equation (7) in
1894 // https://arxiv.org/pdf/1712.05877.pdf
1895 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
1896 "vdup.32 q9, r3\n"
1897 "vadd.i32 q12, q12, q9\n"
1898
1899 // Perform the bias-addition (per the above, we have just folded into
1900 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
1901 "vadd.i32 q14, q14, q12\n"
1902
1903 // LHS/RHS zero points
1904 // Has RHS sums
1905 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1906 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
1907 "beq 401f\n"
1908 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
1909 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1910 // Offset by current col * number of bytes per value
1911 "add r3, r3, r4, lsl #2\n"
1912 "vld1.32 { d12 }, [r3]\n"
1913 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
1914 "vdup.32 q10, r5\n" // create lhs_zero_point_vec
1915 // Subtract rhs_sums * lhs_zero_point, per
1916 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1917 "vmls.i32 q14, q10, d12[0]\n"
1918 "401:\n"
1919
1920 // Has LHS sums
1921 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1922 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
1923 "beq 402f\n"
1924 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
1925 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1926 // Offset by current row * number of bytes per value
1927 "add r2, r2, r4, lsl #2\n"
1928 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
1929
1930 // Load 4 lhs_sums values.
1931 "vld1.32 {d22, d23}, [r2]\n"
1932 "vdup.32 d13, r5\n" // rhs_zero_point
1933
1934 // Compute lhs_sums * rhs_zero_point.
1935 "vmul.i32 q11, q11, d13[1]\n"
1936 // Subtract lhs_sums * rhs_zero_point, per
1937 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1938 "vsub.s32 q14, q14, q11\n"
1939
1940 // If the destination is int32, it means the user asks for the raw
1941 // accumulators, no need for us to downquantize the value.
1942 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1943 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1944 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1945
1946 "402:\n"
1947
1948 // At this point we have computed the final int32 values. Now we
1949 // start down-quantizing them to obtain the final 8bit values from them.
1950
1951 // As part of this down-quantization, our int32 values will be
1952 // multiplied by a multiplier that has a fixed-point component and an
1953 // exponent component.
1954
1955 //Load the exponent part of the multiplier.
1956 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1957 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1958 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1959 "it ne\n"
1960 "addne r1, r1, r4, lsl #2\n"
1961
1962 "vld1.32 {q10}, [r1]\n"
1963
1964 "vmvn.i32 q8, #0\n"
1965 "vmin.s32 q13, q10, q8\n"
1966 "vsub.s32 q12, q10, q13\n"
1967
1968 // Apply the positive exponent part of the multiplier.
1969 "vshl.s32 q14, q14, q12\n"
1970
1971 // Load fixed point part of the multiplier
1972 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1973 // r6 has flags, r4 has row
1974 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1975 "it ne\n"
1976 "addne r1, r1, r4, lsl #2\n"
1977 "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
1978
1979 // Apply the fixed-point part of the multiplier.
1980 "vqdmulh.s32 q14, q14, q10\n"
1981
1982 // Apply the negative exponent part of the multiplier.
1983 "vrshl.s32 q14, q14, q13\n"
1984
1985 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1986 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1987 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1988 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1989 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1990
1991 // Store uint8 values:
1992 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1993
1994 // Cast-and-saturate from int32 to int16
1995 // After this, all values for output are in d28.
1996 "vqmovn.s32 d28, q14\n"
1997
1998 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
1999 // current block, so we can start clearing these accumulators for the
2000 // next block (next iteration of the main loop).
2001 RUY_MAKE_ZERO(q6)
2002 RUY_MAKE_ZERO(q7)
2003 RUY_MAKE_ZERO(q8)
2004 RUY_MAKE_ZERO(q9)
2005 RUY_MAKE_ZERO(q10)
2006 RUY_MAKE_ZERO(q11)
2007 RUY_MAKE_ZERO(q12)
2008 RUY_MAKE_ZERO(q13)
2009 RUY_MAKE_ZERO(q15)
2010
2011 // Load the destination zero point into each of the 8 16-bit slots
2012 // in a q register.
2013 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2014 "vdup.16 q13, r4\n" // dst_zero_point
2015
2016 // Add the destination zero point
2017 "vadd.i16 q14, q14, q13\n"
2018
2019 // Cast-and-saturate from int16 to uint8
2020 "vqmovun.s16 d30, q14\n"
2021 // At this point, we only need 4 8-bit values in the lower half
2022 // of d30.
2023
2024
2025 // Load the clamp_min, clamp_max bounds
2026 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2027 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2028 "vdup.8 d28, r2\n" // clamp_min
2029 "vdup.8 d29, r3\n" // clamp_max
2030
2031 // Apply the clamp_min bound
2032 "vmax.u8 d30, d30, d28\n"
2033 // Apply the clamp_max bound
2034 "vmin.u8 d30, d30, d29\n"
2035
2036 // Compute how much of the 4x1 block of destination 8bit values that
2037 // we have computed, fit in the destination matrix. Typically, all of
2038 // it fits, but when the destination matrix shape is not a multiple
2039 // of 4x1, there are some 4x1 blocks along the boundaries that do
2040 // not fit entirely.
2041
2042 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2043 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2044 "sub r1, r1, r8\n"
2045
2046 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2047 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2048 "sub r2, r2, r4\n"
2049 "mov r3, #4\n"
2050 "mov r5, #2\n"
2051 "cmp r1, #4\n"
2052 // Compute r1 = how many rows of the 4x1 block fit
2053 "it gt\n"
2054 "movgt r1, r3\n"
2055
2056 // Test if r1==4, i.e. if all of the 4x1 block fits.
2057 "cmp r1, r3\n"
2058
2059 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2060 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2061 // Yes, all of the 4x1 block fits, go to fast path.
2062 "beq 30f\n"
2063 // Not all of the 4x1 block fits.
2064 // Store to dst_tmp_buf
2065 // Set r3 address to write to dst_tmp_buf.
2066 "mov r3, %[dst_tmp_buf]\n"
2067 "vst1.8 {d30}, [r3]\n"
2068
2069 // Slow loop copying from dst_tmp_buf to dst.
2070 "50:\n"
2071 "mov r8, #0\n"
2072 "51:\n"
2073 "ldrb r10, [r3, r8]\n"
2074 "strb r10, [r4, r8]\n"
2075 "add r8, r8, #1\n"
2076 "cmp r8, r1\n"
2077 "blt 51b\n"
2078 "b 31f\n"
2079 "30:\n"
2080 // Yes, all of the 4x1 block fits.
2081 // r3 address, r5 stride
2082 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2083 "mov r4, r3\n"
2084 "mov r6, #1\n"
2085
2086 "vst1.8 {d30[0]}, [r3], r6\n"
2087 "vst1.8 {d30[1]}, [r3], r6\n"
2088 "vst1.8 {d30[2]}, [r3], r6\n"
2089 "vst1.8 {d30[3]}, [r3], r6\n"
2090 "31:\n"
2091
2092 // Load dst_ptr, increment, and write back.
2093 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2094 "add r4, r4, #4\n"
2095 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2096
2097 RUY_MAKE_ZERO(q13)
2098 RUY_MAKE_ZERO(q14)
2099 RUY_MAKE_ZERO(q15)
2100
2101 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2102
2103 // Store int8 values:
2104 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
2105
2106 // Cast-and-saturate from int32 to int16
2107 // After this, all values for output are in d28.
2108 "vqmovn.s32 d28, q14\n"
2109
2110 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2111 // current block, so we can start clearing these accumulators for the
2112 // next block (next iteration of the main loop).
2113 RUY_MAKE_ZERO(q6)
2114 RUY_MAKE_ZERO(q7)
2115 RUY_MAKE_ZERO(q8)
2116 RUY_MAKE_ZERO(q9)
2117 RUY_MAKE_ZERO(q10)
2118 RUY_MAKE_ZERO(q11)
2119 RUY_MAKE_ZERO(q12)
2120 RUY_MAKE_ZERO(q13)
2121 RUY_MAKE_ZERO(q15)
2122
2123 // Load the destination zero point into each of the 8 16-bit slots
2124 // in a q register.
2125 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2126 "vdup.16 q13, r4\n" // dst_zero_point
2127
2128 // Add the destination zero point
2129 "vadd.i16 q14, q14, q13\n"
2130
2131 // Cast-and-saturate from int16 to int8
2132 "vqmovn.s16 d30, q14\n"
2133 // At this point, we only need 4 8-bit values in the lower half
2134 // of d30.
2135
2136 // Load the clamp_min, clamp_max bounds
2137 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2138 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2139 "vdup.8 d28, r2\n" // clamp_min
2140 "vdup.8 d29, r3\n" // clamp_max
2141
2142 // Apply the clamp_min bound
2143 "vmax.s8 d30, d30, d28\n"
2144 // Apply the clamp_max bound
2145 "vmin.s8 d30, d30, d29\n"
2146
2147 // Compute how much of the 4x1 block of destination 8bit values that
2148 // we have computed, fit in the destination matrix. Typically, all of
2149 // it fits, but when the destination matrix shape is not a multiple
2150         // of 4x1, there are some 4x1 blocks along the boundaries that do
2151 // not fit entirely.
2152
2153 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2154 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2155 "sub r1, r1, r8\n"
2156
2157 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2158 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2159 "sub r2, r2, r4\n"
2160 "mov r3, #4\n"
2161 "mov r5, #2\n"
2162 "cmp r1, #4\n"
2163 // Compute r1 = how many rows of the 4x2 block fit
2164 "it gt\n"
2165 "movgt r1, r3\n"
2166
2167 // Test if r1==4 i.e. if all of the 4x1 block fits.
2168 "cmp r1, r3\n"
2169
2170 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2171 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2172 // Yes, all of the 4x2 block fits, go to fast path.
2173 "beq 30f\n"
2174 // Not all of the 4x2 block fits.
2175 // Store to dst_tmp_buf
2176 // Set r3 address to write to dst_tmp_buf.
2177 "mov r3, %[dst_tmp_buf]\n"
2178 "vst1.8 {d30}, [r3]\n"
2179
2180 // Slow loop copying from dst_tmp_buf to dst.
2181 "50:\n"
2182 "mov r8, #0\n"
2183 "51:\n"
2184 "ldrb r10, [r3, r8]\n"
2185 "strb r10, [r4, r8]\n"
2186 "add r8, r8, #1\n"
2187 "cmp r8, r1\n"
2188 "blt 51b\n"
2189 "b 31f\n"
2190 "30:\n"
2191 // Yes, all of the 4x1 block fits.
2192 // r3 address, r5 stride
2193 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2194 "mov r4, r3\n"
2195 "mov r6, #1\n"
2196
2197 "vst1.8 {d30[0]}, [r3], r6\n"
2198 "vst1.8 {d30[1]}, [r3], r6\n"
2199 "vst1.8 {d30[2]}, [r3], r6\n"
2200 "vst1.8 {d30[3]}, [r3], r6\n"
2201 "31:\n"
2202
2203 // Load dst_ptr, increment, and write back.
2204 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2205 "add r4, r4, #4\n"
2206 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2207
2208 RUY_MAKE_ZERO(q13)
2209 RUY_MAKE_ZERO(q14)
2210 RUY_MAKE_ZERO(q15)
2211
2212 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2213
2214 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
2215
2216 // Load the destination zero point into each of the 4 32-bit slots
2217 // in a q register.
2218 "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2219 "vdup.32 q13, r4\n" // dst_zero_point
2220 // Add the destination zero point
2221 "vadd.s32 q14, q14, q13\n"
2222 //"vadd.s32 q15, q15, q13\n"
2223
2224 // Cast-and-saturate from int32 to int16
2225 // After this, all values for output are in d28.
2226 "vqmovn.s32 d28, q14\n"
2227
2228 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2229 // so we can start clearing these accumulators for the next block
2230 // (next iteration of the main loop).
2231 RUY_MAKE_ZERO(q6)
2232 RUY_MAKE_ZERO(q7)
2233 RUY_MAKE_ZERO(q8)
2234 RUY_MAKE_ZERO(q9)
2235 RUY_MAKE_ZERO(q10)
2236 RUY_MAKE_ZERO(q11)
2237 RUY_MAKE_ZERO(q15)
2238
2239 // Load the clamp_min, clamp_max bounds
2240 "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2241 "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2242 "vdup.16 d24, r2\n" // clamp_min
2243 "vdup.16 d26, r3\n" // clamp_max
2244
2245 // Apply the clamp_min bound
2246 "vmax.s16 d28, d28, d24\n"
2247 // Apply the clamp_max bound
2248 "vmin.s16 d28, d28, d26\n"
2249
2250 RUY_MAKE_ZERO(q12)
2251 RUY_MAKE_ZERO(q13)
2252
2253 // Compute how much of the 4x1 block of destination 16-bit values that
2254 // we have computed, fit in the destination matrix. Typically, all of
2255 // it fits, but when the destination matrix shape is not a multiple
2256 // of 4x1, there are some 4x1 blocks along the boundaries that do
2257 // not fit entirely.
2258
2259 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2260 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2261 "sub r1, r1, r8\n"
2262
2263 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2264 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2265 "sub r2, r2, r4\n"
2266 "mov r3, #4\n"
2267 "mov r5, #2\n"
2268 "cmp r1, #4\n"
2269 // Compute r1 = how many rows of the 4x1 block fit
2270 "it gt\n"
2271 "movgt r1, r3\n"
2272
2273 // Test if r1==4, i.e. if all of the 4x1 block fits.
2274 "cmp r1, r3\n"
2275
2276 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2277 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2278 // Yes, all of the 4x1 block fits, go to fast path.
2279 "beq 30f\n"
2280 // Not all of the 4x1 block fits.
2281 // Store to dst_tmp_buf
2282 // Set r3 address to write to dst_tmp_buf.
2283 "mov r3, %[dst_tmp_buf]\n"
2284 "vst1.16 {d28}, [r3]\n"
2285
2286 // Slow loop copying from dst_tmp_buf to dst.
2287 "50:\n"
2288 "mov r8, #0\n"
2289 "51:\n"
2290 // Shift of offset register for half-word loads not allowed in A32,
2291 // so we shift, load/store, then shift back r8.
2292 "lsl r8, r8, #1\n"
2293 "ldrh r10, [r3, r8]\n"
2294 "strh r10, [r4, r8]\n"
2295 "lsr r8, r8, #1\n"
2296 "add r8, r8, #1\n"
2297 "cmp r8, r1\n"
2298 "blt 51b\n"
2299 "b 31f\n"
2300 "30:\n"
2301 // Yes, all of the 4x1 block fits.
2302 // r3 address, r5 stride
2303 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2304 "mov r4, r3\n"
2305 "mov r6, #2\n"
2306
2307 "vst1.16 {d28[0]}, [r3], r6\n"
2308 "vst1.16 {d28[1]}, [r3], r6\n"
2309 "vst1.16 {d28[2]}, [r3], r6\n"
2310 "vst1.16 {d28[3]}, [r3], r6\n"
2311 "31:\n"
2312
2313 // Load dst_ptr, increment, and write back.
2314 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2315 "add r4, r4, #8\n"
2316 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2317
2318 RUY_MAKE_ZERO(q14)
2319
2320 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2321
2322 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
2323
2324 // Since the store type is the same as the accum type, no need for
2325 // downcast. There's also no need for clamp by min/max.
2326
2327         // At this point, q6 -- q13 aren't used anymore for the current block,
2328 // so we can start clearing these accumulators for the next block
2329 // (next iteration of the main loop).
2330 // Clear accumulators.
2331 RUY_MAKE_ZERO(q6)
2332 RUY_MAKE_ZERO(q7)
2333 RUY_MAKE_ZERO(q8)
2334 RUY_MAKE_ZERO(q9)
2335 RUY_MAKE_ZERO(q10)
2336 RUY_MAKE_ZERO(q11)
2337 RUY_MAKE_ZERO(q12)
2338 RUY_MAKE_ZERO(q13)
2339
2340 // Compute how much of the 4x1 block of destination 32 bit values that
2341 // we have computed, fit in the destination matrix. Typically, all of
2342 // it fits, but when the destination matrix shape is not a multiple
2343         // of 4x1, there are some 4x1 blocks along the boundaries that do
2344 // not fit entirely.
2345
2346 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2347 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2348 "sub r1, r1, r8\n"
2349
2350 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2351 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2352 "sub r2, r2, r4\n"
2353 "mov r3, #4\n"
2354 "mov r5, #2\n"
2355 "cmp r1, #4\n"
2356 // Compute r1 = how many rows of the 4x2 block fit
2357 "it gt\n"
2358 "movgt r1, r3\n"
2359
2360 // Test if r1==4, i.e. if all of the 4x1 block fits.
2361 "cmp r1, r3\n"
2362
2363 // Yes, all of the 4x1 block fits, go to fast path.
2364 "beq 30f\n"
2365 // Not all of the 4x1 block fits.
2366 // Set (r3 address, r4 stride) to write to dst_tmp_buf
2367 "mov r3, %[dst_tmp_buf]\n"
2368 "mov r4, #16\n"
2369 "b 31f\n"
2370
2371 "30:\n"
2372 // Yes, all of the 4x1 block fits.
2373 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2374 // r3 address, r4 stride
2375 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2376 "mov r4, r5\n"
2377
2378 "31:\n"
2379
2380 "vst1.32 {d28, d29}, [r3]\n"
2381
2382 // If all of the 4x1 block fits, we just finished writing it to the
2383 // destination, so we skip the next part.
2384 "beq 41f\n"
2385 // Not all of the 4x1 block fits in the destination matrix. We just
2386 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
2387 // it to copy into the destination matrix the part that fits.
2388 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2389 "mov r3, %[dst_tmp_buf]\n"
2390 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2391 "50:\n"
2392 "mov r5, #0\n"
2393 "51:\n"
2394 "ldr r10, [r3, r5, lsl #2]\n"
2395 "str r10, [r4, r5, lsl #2]\n"
2396 "add r5, r5, #1\n"
2397 "cmp r5, r1\n"
2398 "blt 51b\n"
2399
2400 "41:\n"
2401 // Load dst_ptr, increment, and write back.
2402 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2403 "add r4, r4, #16\n"
2404 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2405
2406 RUY_MAKE_ZERO(q10)
2407 RUY_MAKE_ZERO(q11)
2408
2409 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2410
2411 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
2412
2413         // Reload some params --- we had used r3, r5, r6 for a few other things
2414 // since the last time we had loaded them.
2415 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2416 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2417 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2418
2419 // Move to the next block of the destination matrix, for the next iter
2420 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
2421 // been updated earlier.
2422 // Have we reached the end row?
2423 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2424 "cmp r8, r3\n"
2425
2426 "beq 20f\n" // yes, end row.
2427 // Not end row. Move to the next row.
2428 "add r8, r8, #4\n"
2429 // Store new value of row
2430 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2431
2432 "b 21f\n"
2433 "20:\n"
2434 // Was already at end row.
2435 // Move back to first row.
2436 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2437 // Move to the next column.
2438 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2439 "add r4, r4, #2\n"
2440 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2441
2442 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2443 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2444 // Increment dst_col_ptr by dst_stride (i.e. 1 column)
2445 "add r1, r1, r8\n"
2446 // Store dst_col_ptr
2447 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2448 // Store dst_ptr
2449 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2450 "21:\n"
2451
2452 // Main loop exit condition: have we hit the end column?
2453 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2454 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2455 "cmp r8, r4\n"
2456
2457         // r1 is the number of levels of depth that we have already loaded
2458 // LHS and RHS data for. Corresponding to the initial ld1 instructions
2459 // above, this is currently 16.
2460 "mov r1, #16\n"
2461
2462 "ble 1b\n"
2463
2464 // Restore stack pointer.
2465 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
2466
2467 // clang-format on
2468
2469 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
2470 : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf)
2471 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
2472 // Clobber list must specify q registers (and not their constituent
2473 // d registers). There is a (currently unexplained) slowdown if
2474 // d registers are listed in the clobbers list.
2475 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
2476 "q9", "q10", "q12", "q13", "q14", "q15");
2477 }
2478
2479 #undef RUY_OFFSET_BIAS
2480 #undef RUY_OFFSET_LHS_SUMS
2481 #undef RUY_OFFSET_RHS_SUMS
2482 #undef RUY_OFFSET_LHS_BASE_PTR
2483 #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
2484 #undef RUY_OFFSET_MULTIPLIER_EXPONENT
2485 #undef RUY_OFFSET_RHS_BASE_PTR
2486 #undef RUY_OFFSET_DST_BASE_PTR
2487 #undef RUY_OFFSET_LHS_ZERO_POINT
2488 #undef RUY_OFFSET_RHS_ZERO_POINT
2489 #undef RUY_OFFSET_DST_ZERO_POINT
2490 #undef RUY_OFFSET_PROD_ZP_DEPTH
2491 #undef RUY_OFFSET_START_ROW
2492 #undef RUY_OFFSET_START_COL
2493 #undef RUY_OFFSET_LAST_ROW
2494 #undef RUY_OFFSET_LAST_COL
2495 #undef RUY_OFFSET_DST_ROWS
2496 #undef RUY_OFFSET_DST_COLS
2497 #undef RUY_OFFSET_LHS_STRIDE
2498 #undef RUY_OFFSET_RHS_STRIDE
2499 #undef RUY_OFFSET_DST_STRIDE
2500 #undef RUY_OFFSET_DEPTH
2501 #undef RUY_OFFSET_CLAMP_MIN
2502 #undef RUY_OFFSET_CLAMP_MAX
2503 #undef RUY_OFFSET_FLAGS
2504 #undef RUY_OFFSET_DST_TYPE_ID
2505
2506 #undef RUY_STACK_OFFSET_SIZE
2507 #undef RUY_STACK_OFFSET_DST_COL_PTR
2508 #undef RUY_STACK_OFFSET_DST_PTR
2509 #undef RUY_STACK_OFFSET_ROW
2510 #undef RUY_STACK_OFFSET_COL
2511 #undef RUY_STACK_OFFSET_LHS_COL_PTR
2512 #undef RUY_STACK_OFFSET_RHS_COL_PTR
2513
2514 #endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
2515 } // namespace ruy
2516