1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/experimental/ruy/kernel.h"
17 #include "tensorflow/lite/experimental/ruy/opt_set.h"
18 #include "tensorflow/lite/experimental/ruy/platform.h"
19 #include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h"
20
21 namespace ruy {
22
23 #if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
24
25 #define RUY_ASM_LABEL_STORE_UINT8 91
26 #define RUY_ASM_LABEL_STORE_INT8 92
27 #define RUY_ASM_LABEL_STORE_INT16 93
28 #define RUY_ASM_LABEL_STORE_INT32 94
29 #define RUY_ASM_LABEL_AFTER_STORE 99
30
31 #define RUY_OFFSET_LHS_BASE_PTR 0
32 #define RUY_OFFSET_RHS_BASE_PTR 4
33 #define RUY_OFFSET_DST_BASE_PTR 8
34 #define RUY_OFFSET_BIAS 12
35 #define RUY_OFFSET_START_ROW 16
36 #define RUY_OFFSET_START_COL 20
37 #define RUY_OFFSET_LAST_ROW 24
38 #define RUY_OFFSET_LAST_COL 28
39 #define RUY_OFFSET_DST_ROWS 32
40 #define RUY_OFFSET_DST_COLS 36
41 #define RUY_OFFSET_LHS_STRIDE 40
42 #define RUY_OFFSET_RHS_STRIDE 44
43 #define RUY_OFFSET_DST_STRIDE 48
44 #define RUY_OFFSET_DEPTH 52
45 #define RUY_OFFSET_CLAMP_MIN 56
46 #define RUY_OFFSET_CLAMP_MAX 60
47 #define RUY_OFFSET_FLAGS 64
48
49 #define RUY_STACK_OFFSET_SIZE 96
50 #define RUY_STACK_OFFSET_DST_COL_PTR 0
51 #define RUY_STACK_OFFSET_DST_PTR 16
52 #define RUY_STACK_OFFSET_ROW 32
53 #define RUY_STACK_OFFSET_COL 48
54 #define RUY_STACK_OFFSET_LHS_COL_PTR 64
55 #define RUY_STACK_OFFSET_RHS_COL_PTR 80
56
57 template <typename Params>
CheckOffsetsInKernelParamsFloat32(const Params &)58 void CheckOffsetsInKernelParamsFloat32(const Params&) {
59 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
60 static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
61 static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
62 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
63 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
64 static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
65 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
66 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
67 static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
68 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
69 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
70 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
71 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
72 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
73 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
74 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
75 }
76
77 // Float kernel for ARM32 out-of-order cores.
78 // Just like Float 64 version, except accumulate in to 8x4 block to only
79 // use 16 128-bit NEON registers. This is a "first pass" kernel and not
80 // tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9.
KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8,4> & params)81 void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) {
82 CheckOffsetsInKernelParamsFloat32(params);
83 profiler::ScopeLabel label(
84 "Kernel (kNeon, optimized for out-of-order cores)");
85
86 const float* lhs_ptr = params.lhs_base_ptr;
87 const float* rhs_ptr = params.rhs_base_ptr;
88 // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
89 // each composed of two 64-bit "d" registers. The asm kernel below has the
90 // following NEON register allocation:
91 // Registers q3 -- q10 are accumulators. During accumulation,
92 // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
93 // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block
94 // of RHS, like this:
95
96 // Register layout in "q" registers:
97 // RHS 1x4 block
98 // /--------------------------\
99 // |q2.s[0] ... q2.s[3] |
100 // \--------------------------/
101 // LHS 8x1 block
102 // /---------------------\ /--------------------- \
103 // | q0.s[0] | | q3.s[0] ... q9.s[0] |
104 // | ... | | ... ... |
105 // | q0.s[3] | | q3.s[3] q9.s[3] |
106 // | q1.s[0] | | q4.s[0] q10.s[0] |
107 // | ... | | ... ... ... |
108 // | q1.s[3] | | q4.s[3] .. q10.s[3] |
109 // \---------------------/ \--------------------------/
110 // accumulators 8x4 block
111 // q11, q14, q15 currently unused. q12 and q13 are used to load
112 // parameters used for the post-accumulation part of the kernel.
113 // For completeness, here is the register layout in "d" registers:
114 // RHS 1x4 block
115 // /--------------------------\
116 // |d4[0] ... d5[1] |
117 // \--------------------------/
118 // LHS 8x1 block
119 // /---------------------\ /--------------------------\
120 // | d0[0] | | d6[0] ... d18[0] |
121 // | ... | | ... ... |
122 // | d1[1] | | d7[1] d19[1] |
123 // | d2[0] | | d8[0] d20[0] |
124 // | ... | | ... ... ... |
125 // | d3[1] | | d9[1] ... d21[1] |
126 // \---------------------/ \--------------------------/
127 // accumulators 8x4 block
128 asm volatile(
129 #define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
130
131 // clang-format off
132
133 // Load the first 32 bytes of LHS and RHS data.
134 // Load q0, q1
135 "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
136 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
137 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
138 // Load q2
139 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
140 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
141
142 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
143
144 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
145 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
146
147 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
148 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
149
150 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
151 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
152
153 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
154 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
155
156 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
157 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
158
159 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
160 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
161 // Clear accumulators.
162 RUY_MAKE_ZERO(q3)
163 RUY_MAKE_ZERO(q4)
164 RUY_MAKE_ZERO(q5)
165 RUY_MAKE_ZERO(q6)
166 RUY_MAKE_ZERO(q7)
167 RUY_MAKE_ZERO(q8)
168 RUY_MAKE_ZERO(q9)
169 RUY_MAKE_ZERO(q10)
170
171 // r1 is the number of levels of depth that we have already loaded
172 // LHS and RHS data for. Corresponding to the initial ld1 instructions
173 // above, this is currently 1.
174 "mov r1, #1\n"
175
176 // Main loop of the whole GEMM, over rows and columns of the
177 // destination matrix.
178 "1:\n"
179
180 // Accumulation loop
181 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
182 "cmp r1, r2\n"
183 "beq 79f\n"
184
185 "2:\n"
186
187 "vmla.f32 q3, q0, d4[0]\n"
188 "vmla.f32 q5, q0, d4[1]\n"
189 "vmla.f32 q7, q0, d5[0]\n"
190 "vmla.f32 q9, q0, d5[1]\n"
191 "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
192
193 "vmla.f32 q4, q1, d4[0]\n"
194 "vmla.f32 q6, q1, d4[1]\n"
195 "vmla.f32 q8, q1, d5[0]\n"
196 "vmla.f32 q10, q1, d5[1]\n"
197 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
198 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
199 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
200 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
201
202 "add r1, r1, #1\n"
203 "cmp r1, r2\n"
204
205 "blt 2b\n"
206
207 "79:\n"
208
209 // End of the inner loop on depth. Now perform the remaining
210 // multiply-adds of the last level of depth, for which the LHS
211 // and RHS data is already loaded.
212
213 "vmla.f32 q3, q0, d4[0]\n"
214 "vmla.f32 q5, q0, d4[1]\n"
215 "vmla.f32 q7, q0, d5[0]\n"
216 "vmla.f32 q9, q0, d5[1]\n"
217
218 "vmla.f32 q4, q1, d4[0]\n"
219 "vmla.f32 q6, q1, d4[1]\n"
220 "vmla.f32 q8, q1, d5[0]\n"
221 "vmla.f32 q10, q1, d5[1]\n"
222
223 // End of accumulation. The registers q3 -- q10 contain the final
224 // float32 accumulator values of the current 8x8 destination block.
225 // We now have to compute the final values from these accumulators
226 // and advance to the next 8x8 block. We intertwine
227 // these two aspects whenever possible for optimal pipelining, both
228 // at the data flow level (prefetch data for next block as early as
229 // possible) and instruction pipelining level (some of the next-block
230 // work can dual-issue with some of the final work on the current
231 // block).
232
233 // Logic to advance to the next block in preparation for the next
234 // iteration of the main loop. For now, we only want to compute
235 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
236 // not yet ready to update the values of row and col, as we still need
237 // the current values for the rest of the work on the current block.
238
239 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
240 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
241 "cmp r1, r3\n" // Have we finished the last row?
242
243 "bge 4f\n" // If finished last row, go to 4
244 // Not finished last row: then advance to next row.
245 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
246 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
247 "add r4, r4, r1, lsl #3\n"
248 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
249 "b 5f\n"
250 "4:\n" // Finished last row...
251 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
252 // Go back to first row
253 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
254 // Now we need to advance to the next column. If we already
255 // finished the last column, then in principle we are done, however
256 // we can't just return here, as we need to allow the end work of the
257 // current block to complete. The good news is that at this point it
258 // doesn't matter what data we load for the next column, since
259 // we will exit from the main loop below before actually storing
260 // anything computed from that data.
261 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
262 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
263 "cmp r8, r4\n" // Have we finished the last column?
264 "bge 5f\n" // If yes, just carry on without updating the column pointer.
265 // Not finished last column: then advance to next column.
266 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
267 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
268 "add r10, r10, r1, lsl #2\n"
269 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
270 "5:\n"
271
272 // Set the LHS and RHS data pointers to the start of the columns just
273 // computed.
274 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
275 "mov %[lhs_ptr], r4\n"
276 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
277 "mov %[rhs_ptr], r5\n"
278
279 // Load some parameters needed for the end work on current block.
280 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
281 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
282
283 // Offset these base pointers as needed given the current row, col.
284 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
285 "add r5, r1, r8, lsl #2\n"
286
287 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
288 "it ne\n"
289 "movne r1, r5\n"
290
291 // Load 8 bias values.
292 "vld1.32 {d24, d25, d26, d27}, [r1]\n"
293
294 // Now that we know what LHS and RHS data the next iteration of the
295 // main loop will need to load, we start loading the first 32 bytes of
296 // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore
297 // in the rest of the work on the current block.
298 // Load q0, q1
299 "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
300 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
301 // Load q2
302 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
303 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
304
305 // Perform the bias-addition (per the above, we have just folded into
306 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
307 "vadd.f32 q3, q3, q12\n"
308 "vadd.f32 q4, q4, q13\n"
309 "vadd.f32 q5, q5, q12\n"
310 "vadd.f32 q6, q6, q13\n"
311 "vadd.f32 q7, q7, q12\n"
312 "vadd.f32 q8, q8, q13\n"
313 "vadd.f32 q9, q9, q12\n"
314 "vadd.f32 q10, q10, q13\n"
315
316 // Load the clamp_min, clamp_max bounds
317 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
318 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
319 "vdup.32 q12, r2\n" // clamp_min
320 "vdup.32 q13, r3\n" // clamp_max
321
322 // Apply the clamp_min bound
323 "vmax.f32 q3, q3, q12\n"
324 "vmax.f32 q4, q4, q12\n"
325 "vmax.f32 q5, q5, q12\n"
326 "vmax.f32 q6, q6, q12\n"
327 "vmax.f32 q7, q7, q12\n"
328 "vmax.f32 q8, q8, q12\n"
329 "vmax.f32 q9, q9, q12\n"
330 "vmax.f32 q10, q10, q12\n"
331
332 // Apply the clamp_max bound
333 "vmin.f32 q3, q3, q13\n"
334 "vmin.f32 q4, q4, q13\n"
335 "vmin.f32 q5, q5, q13\n"
336 "vmin.f32 q6, q6, q13\n"
337 "vmin.f32 q7, q7, q13\n"
338 "vmin.f32 q8, q8, q13\n"
339 "vmin.f32 q9, q9, q13\n"
340 "vmin.f32 q10, q10, q13\n"
341
342 // Compute how much of the 8x4 block of destination values that
343 // we have computed, fit in the destination matrix. Typically, all of
344 // it fits, but when the destination matrix shape is not a multiple
345 // of 8x4, there are some 8x8 blocks along the boundaries that do
346 // not fit entirely.
347 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
348 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
349 "sub r1, r1, r8\n"
350
351 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
352 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
353 "sub r2, r2, r4\n"
354 "mov r3, #8\n"
355 "mov r5, #4\n"
356 "cmp r1, #8\n"
357 // Compute r1 = how many rows of the 8x4 block fit
358 "it gt\n"
359 "movgt r1, r3\n"
360 "cmp r2, #4\n"
361 // Compute r2 = how many cols of the 8x4 block fit
362 "it gt\n"
363 "movgt r2, r5\n"
364
365 // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
366 "cmp r1, r3\n"
367 "it eq\n"
368 "cmpeq r2, r5\n"
369 // Yes, all of the 8x4 block fits, go to fast path.
370 "beq 30f\n"
371 // Not all of the 8x4 block fits.
372 // Set (r3 address, r4 stride) to write to dst_tmp_buf
373 "mov r3, %[dst_tmp_buf]\n"
374 "mov r4, #32\n"
375 "b 31f\n"
376 "30:\n"
377 // Yes, all of the 8x4 block fits.
378 // Set (r3 address, r4 stride) to write directly to destination matrix.
379 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
380 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
381 "mov r4, r5\n"
382 "31:\n"
383
384 // Write our float values to the destination described by
385 // (r3 address, r4 stride)
386 "vst1.32 {d6, d7, d8, d9}, [r3]\n"
387 "add r3, r3, r4\n"
388 RUY_MAKE_ZERO(q3)
389 RUY_MAKE_ZERO(q4)
390 "vst1.32 {d10, d11, d12, d13}, [r3]\n"
391 "add r3, r3, r4\n"
392 RUY_MAKE_ZERO(q5)
393 RUY_MAKE_ZERO(q6)
394 "vst1.32 {d14, d15, d16, d17}, [r3]\n"
395 "add r3, r3, r4\n"
396 RUY_MAKE_ZERO(q7)
397 RUY_MAKE_ZERO(q8)
398 "vst1.32 {d18, d19, d20, d21}, [r3]\n"
399 "add r3, r3, r4\n"
400 RUY_MAKE_ZERO(q9)
401 RUY_MAKE_ZERO(q10)
402
403 // If all of the 8x4 block fits, we just finished writing it to the
404 // destination, so we skip the next part.
405 "beq 41f\n"
406 // Not all of the 8x8 block fits in the destination matrix. We just
407 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
408 // it to copy into the destination matrix the part that fits.
409 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
410 "mov r3, %[dst_tmp_buf]\n"
411 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
412 "mov r6, #0\n"
413 "50:\n"
414 "mov r5, #0\n"
415 "51:\n"
416 "ldr r10, [r3, r5, lsl #2]\n"
417 "str r10, [r4, r5, lsl #2]\n"
418 "add r5, r5, #1\n"
419 "cmp r5, r1\n"
420 "blt 51b\n"
421 "add r6, r6, #1\n"
422 "add r3, r3, #32\n"
423 "add r4, r4, r8\n"
424 // r2 = how many cols of the 8x4 block fit
425 "cmp r6, r2\n"
426 "blt 50b\n"
427 "41:\n"
428 // Load dst_ptr, increment, and write back.
429 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
430 "add r4, r4, #32\n"
431 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
432 // At this point we have completely finished writing values to the
433 // destination matrix for the current block.
434
435 // Reload some params --- we had used r3, r5, r10 for a few other things
436 // since the last time we had loaded them.
437 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
438 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
439 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
440
441 // Move to the next block of the destination matrix, for the next iter
442 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
443 // been updated earlier.
444 // Have we reached the end row?
445 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
446 "cmp r8, r3\n"
447
448 "beq 20f\n" // yes, end row.
449 // Not end row. Move to the next row.
450 "add r8, r8, #8\n"
451 // Store new value of row
452 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
453
454 "b 21f\n"
455 "20:\n"
456 // Was already at end row.
457 // Move back to first row.
458 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
459 // Move to the next column.
460 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
461 "add r4, r4, #4\n"
462 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
463
464 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
465 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
466 // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
467 "add r1, r1, r8, lsl #2\n"
468 // Store dst_col_ptr
469 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
470 // Store dst_ptr
471 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
472 "21:\n"
473
474 // Main loop exit condition: have we hit the end column?
475 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
476 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
477 "cmp r8, r4\n"
478
479 // r1 is the number of levels of depth that we have already loaded
480 // LHS and RHS data for. Corresponding to the initial ld1 instructions
481 // above, this is currently 1.
482 "mov r1, #1\n"
483
484 "ble 1b\n"
485
486 // Restore stack pointer.
487 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
488
489 // clang-format on
490 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
491 : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf)
492 // Clobber list must specify q registers (and not their constituent
493 // d registers). There is a (currently unexplained) slowdown if
494 // d registers are listed in the clobbers list.
495 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
496 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
497 "q9", "q10", "q12", "q13");
498 }
499
500 #undef RUY_MAKE_ZERO
501 #undef RUY_STACK_OFFSET_SIZE
502 #undef RUY_STACK_OFFSET_DST_COL_PTR
503 #undef RUY_STACK_OFFSET_DST_PTR
504 #undef RUY_STACK_OFFSET_ROW
505 #undef RUY_STACK_OFFSET_COL
506 #undef RUY_STACK_OFFSET_LHS_COL_PTR
507 #undef RUY_STACK_OFFSET_RHS_COL_PTR
508
509 #undef RUY_OFFSET_LHS_BASE_PTR
510 #undef RUY_OFFSET_RHS_BASE_PTR
511 #undef RUY_OFFSET_DST_BASE_PTR
512 #undef RUY_OFFSET_BIAS
513 #undef RUY_OFFSET_START_ROW
514 #undef RUY_OFFSET_START_COL
515 #undef RUY_OFFSET_LAST_ROW
516 #undef RUY_OFFSET_LAST_COL
517 #undef RUY_OFFSET_DST_ROWS
518 #undef RUY_OFFSET_DST_COLS
519 #undef RUY_OFFSET_LHS_STRIDE
520 #undef RUY_OFFSET_RHS_STRIDE
521 #undef RUY_OFFSET_DST_STRIDE
522 #undef RUY_OFFSET_DEPTH
523 #undef RUY_OFFSET_CLAMP_MIN
524 #undef RUY_OFFSET_CLAMP_MAX
525 #undef RUY_OFFSET_FLAGS
526
527 #define RUY_OFFSET_BIAS 0
528 #define RUY_OFFSET_LHS_SUMS 4
529 #define RUY_OFFSET_RHS_SUMS 8
530 #define RUY_OFFSET_LHS_BASE_PTR 12
531 #define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
532 #define RUY_OFFSET_MULTIPLIER_EXPONENT 20
533 #define RUY_OFFSET_RHS_BASE_PTR 24
534 #define RUY_OFFSET_DST_BASE_PTR 28
535 #define RUY_OFFSET_LHS_ZERO_POINT 32
536 #define RUY_OFFSET_RHS_ZERO_POINT 36
537 #define RUY_OFFSET_DST_ZERO_POINT 40
538 #define RUY_OFFSET_PROD_ZP_DEPTH 44
539 #define RUY_OFFSET_START_ROW 48
540 #define RUY_OFFSET_START_COL 52
541 #define RUY_OFFSET_LAST_ROW 56
542 #define RUY_OFFSET_LAST_COL 60
543 #define RUY_OFFSET_DST_ROWS 64
544 #define RUY_OFFSET_DST_COLS 68
545 #define RUY_OFFSET_LHS_STRIDE 72
546 #define RUY_OFFSET_RHS_STRIDE 76
547 #define RUY_OFFSET_DST_STRIDE 80
548 #define RUY_OFFSET_DEPTH 84
549 #define RUY_OFFSET_CLAMP_MIN 88
550 #define RUY_OFFSET_CLAMP_MAX 92
551 #define RUY_OFFSET_FLAGS 96
552 #define RUY_OFFSET_DST_TYPE_ID 97
553
554 #define RUY_STACK_OFFSET_SIZE 96
555 #define RUY_STACK_OFFSET_DST_COL_PTR 0
556 #define RUY_STACK_OFFSET_DST_PTR 16
557 #define RUY_STACK_OFFSET_ROW 32
558 #define RUY_STACK_OFFSET_COL 48
559 #define RUY_STACK_OFFSET_LHS_COL_PTR 64
560 #define RUY_STACK_OFFSET_RHS_COL_PTR 80
561
562 template <typename Params>
CheckOffsetsInKernelParams8bit(const Params &)563 void CheckOffsetsInKernelParams8bit(const Params&) {
564 static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
565 "");
566 static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
567 "");
568 static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
569 "");
570 static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
571 "");
572 static_assert(offsetof(Params, multiplier_fixedpoint) ==
573 RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
574 "");
575 static_assert(
576 offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
577 "");
578 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
579 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
580 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
581 static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
582 static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
583 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
584 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
585 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
586 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
587 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
588 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
589 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
590 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
591 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
592 }
593
594 // Fast-int8 kernel, ported from ARM 64 version.
595 // Relevant target CPUs for this kernel include Krait 400 and A9,
596 // since these are 32-bit, out-of-order CPUs.
Kernel8bitNeonOutOfOrder(const KernelParams8bit<4,2> & params)597 void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {
598 profiler::ScopeLabel label(
599 "Kernel (kNeon, optimized for out-of-order cores)");
600
601 CheckOffsetsInKernelParams8bit(params);
602
603 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
604 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
605 const std::int8_t* lhs_ptr = lhs_col_ptr;
606 const std::int8_t* rhs_ptr = rhs_col_ptr;
607
608 // The asm kernel below has the following NEON register allocation:
609 //
610 // q6 - q13 are 128-bit (4x32b) accumulators.
611 // During accumulation, d0 -- d7 are used to load int8 data from LHS and
612 // d8 -- d11 from RHS:
613 // int8 RHS 16x2 block
614 // /-----------------------------\
615 // |d8.b[0-7] ..... d10.b[0-7]|
616 // | ... ... |
617 // |d9.b[0-7] ..... d11.b[0-7]|
618 // \-----------------------------/
619 // int8 LHS 4x16 block
620 // /------------------------\ /-----------------------------\
621 // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
622 // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
623 // (Reload d0, d1, d2, d3)
624 // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
625 // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
626 // \------------------------/ \-----------------------------/
627 // 128-bit accumulators 4x2 block
628 //
629 // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
630 // optimization for this kernel.
631 asm volatile(
632 #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
633
634 // clang-format off
635
636 // Load the first 64 bytes of LHS and RHS data.
637 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
638 // Clear accumulators.
639 RUY_MAKE_ZERO(q6)
640 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
641 RUY_MAKE_ZERO(q8)
642 RUY_MAKE_ZERO(q9)
643 RUY_MAKE_ZERO(q10)
644 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
645 RUY_MAKE_ZERO(q11)
646 "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
647
648 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
649
650 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
651 RUY_MAKE_ZERO(q12)
652 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
653
654 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
655 RUY_MAKE_ZERO(q13)
656 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
657
658 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
659 RUY_MAKE_ZERO(q14)
660 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
661
662 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
663 RUY_MAKE_ZERO(q15)
664 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
665
666 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
667 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
668
669 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
670 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
671
672
673 // r1 is the number of levels of depth that we have already loaded
674 // LHS and RHS data for. Corresponding to the initial ld1 instructions
675 // above, this is currently 16.
676 "mov r1, #16\n"
677
678 // Main loop of the whole GEMM, over rows and columns of the
679 // destination matrix.
680 "1:\n"
681
682 // r1 is how many levels of depth we have already loaded
683 // data for, r10 is the total depth.
684 "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
685 "cmp r1, r10\n"
686 "beq 79f\n"
687
688 "2:\n"
689
690 // Mult, mult-acc in to q14, q15, q2, q3
691 "vmull.s8 q14, d0, d8\n"
692 "vmull.s8 q2, d0, d10\n"
693
694 "vmull.s8 q15, d2, d8\n"
695 "vmull.s8 q3, d2, d10\n"
696
697 "vmlal.s8 q14, d1, d9\n"
698 "vmlal.s8 q2, d1, d11\n"
699 "vmlal.s8 q15, d3, d9\n"
700 "vmlal.s8 q3, d3, d11\n"
701 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
702
703 // Then pairwise accumulate in to q6, q7, q10, q11
704 "vpadal.s16 q6, q14\n"
705 "vpadal.s16 q7, q15\n"
706 "vpadal.s16 q10, q2\n"
707 "vpadal.s16 q11, q3\n"
708
709 // Mult, mult-acc in to q14, q15, q2, q3
710 "vmull.s8 q14, d0, d8\n"
711 "vmull.s8 q2, d0, d10\n"
712
713 "vmull.s8 q15, d2, d8\n"
714 "vmull.s8 q3, d2, d10\n"
715
716 "vmlal.s8 q14, d1, d9\n"
717 "vmlal.s8 q2, d1, d11\n"
718 "vmlal.s8 q15, d3, d9\n"
719 "vmlal.s8 q3, d3, d11\n"
720 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
721
722 // Then pairwise accumulate in to q8, q9, q12, q13
723 "vpadal.s16 q8, q14\n"
724 "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
725 "vpadal.s16 q9, q15\n"
726 "vpadal.s16 q12, q2\n"
727 "vpadal.s16 q13, q3\n"
728
729 // Prefetch the next 64 bytes of LHS and RHS data.
730 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
731 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
732
733 // Each iteration of this loop advances by 16 levels of depth.
734 "add r1, r1, #16\n"
735
736 // Loop termination condition
737 "cmp r1, r10\n"
738
739 "blt 2b\n"
740
741 "79:\n"
742
743 // Mult, mult-acc in to q14, q15, q2, q3
744 "vmull.s8 q14, d0, d8\n"
745 "vmull.s8 q2, d0, d10\n"
746
747 "vmull.s8 q15, d2, d8\n"
748 "vmull.s8 q3, d2, d10\n"
749
750 "vmlal.s8 q14, d1, d9\n"
751 "vmlal.s8 q2, d1, d11\n"
752 "vmlal.s8 q15, d3, d9\n"
753 "vmlal.s8 q3, d3, d11\n"
754 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
755
756 // Then pairwise accumulate in to q6, q7, q10, q11
757 "vpadal.s16 q6, q14\n"
758 "vpadal.s16 q7, q15\n"
759 "vpadal.s16 q10, q2\n"
760 "vpadal.s16 q11, q3\n"
761
762 // Mult, mult-acc in to q14, q15, q2, q3
763 "vmull.s8 q14, d0, d8\n"
764 "vmull.s8 q2, d0, d10\n"
765
766 "vmull.s8 q15, d2, d8\n"
767 "vmull.s8 q3, d2, d10\n"
768
769 "vmlal.s8 q14, d1, d9\n"
770 "vmlal.s8 q2, d1, d11\n"
771 "vmlal.s8 q15, d3, d9\n"
772 "vmlal.s8 q3, d3, d11\n"
773
774 // Then pairwise accumulate in to q8, q9, q12, q13
775 "vpadal.s16 q8, q14\n"
776 "vpadal.s16 q9, q15\n"
777 "vpadal.s16 q12, q2\n"
778 "vpadal.s16 q13, q3\n"
779
780
781 // All accumulation over depth done. q6 - q13 contain the 4x32b
782 // accumulators for the 4x2 final matrix.
783 // We now have to compute the final 8-bit values from these int32
784 // accumulators, and advance to the next 4x2 block. We intertwine
785 // these two aspects whenever possible for optimal pipelining, both
786 // at the data flow level (prefetch data for next block as early as
787 // possible) and instruction pipelining level (some of the next-block
788 // work can dual-issue with some of the final work on the current
789 // block).
790
791 // q6-q13 now contain 4 x 32b
792 "vpadd.i32 d0, d12, d13\n"
793 "vpadd.i32 d1, d14, d15\n"
794 "vpadd.i32 d2, d16, d17\n"
795 "vpadd.i32 d3, d18, d19\n"
796 "vpadd.i32 d4, d20, d21\n"
797 "vpadd.i32 d5, d22, d23\n"
798 "vpadd.i32 d6, d24, d25\n"
799 "vpadd.i32 d7, d26, d27\n"
800
801 // d0-d7 each contain 2 x 32b accumulators.
802 // Need to add pairwise to get 1 x 32b for each of the 4x2 entries
803 // of destination, (Four 'd' registers total)
804 "vpadd.i32 d28, d0, d1\n"
805 "vpadd.i32 d29, d2, d3\n"
806 "vpadd.i32 d30, d4, d5\n"
807 "vpadd.i32 d31, d6, d7\n"
808
809 //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries
810
811 // Logic to advance to the next block in preparation for the next
812 // iteration of the main loop. For now, we only want to compute
813 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
814 // not yet ready to update the values of row and col, as we still need
815 // the current values for the rest of the work on the current block.
816
817 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
818 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
819 "cmp r1, r3\n" // Have we finished the last row?
820
821 "bge 4f\n" // If finished last row, go to 4
822 // Not finished last row: then advance to next row.
823 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
824 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
825 "add r4, r4, r1, lsl #2\n"
826 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
827 "b 5f\n"
828 "4:\n" // Finished last row...
829 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
830 // Go back to first row
831 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
832
833 // Now we need to advance to the next column. If we already
834 // finished the last column, then in principle we are done, however
835 // we can't just return here, as we need to allow the end work of the
836 // current block to complete. The good news is that at this point it
837 // doesn't matter what data we load for the next column, since
838 // we will exit from the main loop below before actually storing
839 // anything computed from that data.
840
841 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
842 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
843 "cmp r8, r4\n" // Have we finished the last column?
844 "bge 5f\n" // If yes, just carry on without updating the column pointer.
845 // Not finished last column: then advance to next column.
846 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
847 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
848 "add r10, r10, r1, lsl #1\n"
849 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
850 "5:\n"
851
852 // Set the LHS and RHS data pointers to the start of the columns just
853 // computed.
854 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
855 "mov %[lhs_ptr], r4\n"
856 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
857 "mov %[rhs_ptr], r5\n"
858
859 // Now we load: bias data, LHS sums data, RHS sums data.
860
861 // First, load the base pointers from the params.
862 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
863 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
864
865 // Offset these base pointers as needed given the current row, col.
866 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
867 "add r5, r1, r8, lsl #2\n"
868
869 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
870 "it ne\n"
871 "movne r1, r5\n"
872
873 // Load 4 bias values.
874 "vld1.32 {d24, d25}, [r1]\n"
875
876 // Now that we know what LHS and RHS data the next iteration of the
877 // main loop will need to load, we start loading the first 32 bytes of
878 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
879 // in the rest of the work on the current block.
880 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
881 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
882 "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
883 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
884
885 // Add to the bias values the product
886 // (depth * lhs_zero_point * rhs_zero_point),
887 // See the term NZ1Z2 in equation (7) in
888 // https://arxiv.org/pdf/1712.05877.pdf
889 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
890 "vdup.32 q9, r3\n"
891 "vadd.i32 q12, q12, q9\n"
892
893 // Perform the bias-addition (per the above, we have just folded into
894 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
895 "vadd.i32 q14, q14, q12\n"
896 "vadd.i32 q15, q15, q12\n"
897
898 // LHS/RHS zero points
899 // Has RHS sums
900 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
901 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
902 "beq 401f\n"
903 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
904 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
905 // Offset by current col * number of bytes per value
906 "add r3, r3, r4, lsl #2\n"
907 "vld1.32 { d12 }, [r3]\n"
908 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
909 "vdup.32 q10, r5\n" // create lhs_zero_point_vec
910 // Subtract rhs_sums * lhs_zero_point, per
911 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
912 "vmls.i32 q14, q10, d12[0]\n"
913 "vmls.i32 q15, q10, d12[1]\n"
914 "401:\n"
915
916 // Has LHS sums
917 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
918 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
919 "beq 402f\n"
920 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
921 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
922 // Offset by current row * number of bytes per value
923 "add r2, r2, r4, lsl #2\n"
924 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
925
926 // Load 4 lhs_sums values.
927 "vld1.32 {d22, d23}, [r2]\n"
928 "vdup.32 d13, r5\n" // rhs_zero_point
929
930 // Compute lhs_sums * rhs_zero_point.
931 "vmul.i32 q11, q11, d13[1]\n"
932 // Subtract lhs_sums * rhs_zero_point, per
933 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
934 "vsub.s32 q14, q14, q11\n"
935 "vsub.s32 q15, q15, q11\n"
936
937 // If the destination is int32, it means the user asks for the raw
938 // accumulators, no need for us to downquantize the value.
939 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
940 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
941 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
942
943 "402:\n"
944
945 // At this point we have computed the final int32 values. Now we
946 // start down-quantizing them to obtain the final 8bit values from them.
947
948 // As part of this down-quantization, our int32 values will be
949 // multiplied by a multiplier that has a fixed-point component and an
950 // exponent component.
951
952 //Load the exponent part of the multiplier.
953 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
954 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
955 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
956 "add r5, r1, r4, lsl #2\n"
957 "it ne\n"
958 "movne r1, r5\n"
959
960 "vld1.32 {q10}, [r1]\n"
961
962 RUY_MAKE_ZERO(q8)
963 "vmax.s32 q12, q10, q8\n"
964
965 "vshl.s32 q14, q14, q12\n"
966 "vshl.s32 q15, q15, q12\n"
967
968 "vmin.s32 q12, q10, q8\n"
969
970 // Load fixed point part of the multiplier
971 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
972 // r6 has flags, r4 has row
973 "add r5, r1, r4, lsl #2\n"
974 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
975 "it ne\n"
976 "movne r1, r5\n"
977 "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
978
979 // Apply the fixed-point part of the multiplier.
980 "vqrdmulh.s32 q14, q14, q10\n"
981 "vqrdmulh.s32 q15, q15, q10\n"
982
983 // We have some rounding division-by-power-of-two to do. This should
984 // always use "round to nearest". We allow for some
985 // freedom in how ties are broken, to strike a good compromise of
986 // performance on given hardware vs. perfect agreement of results
987 // across hardware.
988 //
989 // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
990 // defined tie-breaks to help performance. On NEON, this means that we
991 // can just use the NEON rounding instructions, such as srshl. They
992 // happen to be breaking ties upward.
993 //
994 // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
995 // break-ties-away-from zero, as described in Appendix B of
996 // https://arxiv.org/pdf/1712.05877.pdf
997 // When we wrote that, we thought that that would be better unbiased
998 // than the NEON upwards tie-breaks, and we had observed some
999 // improvement on some model. However, that is only more unbiased for
1000 // data centered at zero, which was likely the case in that model,
1001 // but is not always the case. If we wanted something more consistently
1002 // unbiased then we should try breaking ties toward-nearest-even.
1003 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
1004 // Fix up values to be right-shifted, so that the (round to nearest,
1005 // break ties upward) behavior of srshl applied to these fixed-up
1006 // values, produces the same result as the desired (round to nearest,
1007 // break ties away from zero) behavior on the original values.
1008 "vand q8, q14, q12\n"
1009 "vand q9, q15, q12\n"
1010 "vshr.s32 q8, q8, #31\n"
1011 "vshr.s32 q9, q9, #31\n"
1012 "vqadd.s32 q14, q14, q8\n"
1013 "vqadd.s34 q15, q15, q9\n"
1014
1015 #endif
1016 // At this point we have reduced the problem of correctly implementing
1017 // rounding divide-by-power-of-two, to what the SRSHL instruction can
1018 // do.
1019 "vrshl.s32 q14, q14, q12\n"
1020 "vrshl.s32 q15, q15, q12\n"
1021
1022 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1023 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1024 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1025 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1026 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1027
1028 // Store uint8 values:
1029 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1030
1031 // Cast-and-saturate from int32 to int16
1032 // After this, all values for output are in q14.
1033 "vqmovn.s32 d28, q14\n"
1034 "vqmovn.s32 d29, q15\n"
1035
1036 // At this point, d12 -- d26, d30, d31 aren't used anymore for the
1037 // current block, so we can start clearing these accumulators for the
1038 // next block (next iteration of the main loop).
1039 RUY_MAKE_ZERO(q6)
1040 RUY_MAKE_ZERO(q7)
1041 RUY_MAKE_ZERO(q8)
1042 RUY_MAKE_ZERO(q9)
1043 RUY_MAKE_ZERO(q10)
1044 RUY_MAKE_ZERO(q11)
1045 RUY_MAKE_ZERO(q12)
1046 RUY_MAKE_ZERO(q13)
1047 RUY_MAKE_ZERO(q15)
1048
1049 // Load the destination zero point into each of the 8 16-bit slots
1050 // in a q register.
1051 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1052 "vdup.16 q13, r4\n" // dst_zero_point
1053
1054 // Add the destination zero point
1055 "vadd.i16 q14, q14, q13\n"
1056
1057 // Cast-and-saturate from int16 to uint8
1058 // Now all 8 1-byte values are in d30.
1059 "vqmovun.s16 d30, q14\n"
1060
1061 // Load the clamp_min, clamp_max bounds
1062 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1063 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1064 "vdup.8 d28, r2\n" // clamp_min
1065 "vdup.8 d29, r3\n" // clamp_max
1066
1067 // Apply the clamp_min bound
1068 "vmax.u8 d30, d30, d28\n"
1069 // Apply the clamp_max bound
1070 "vmin.u8 d30, d30, d29\n"
1071
1072 // Compute how much of the 4x2 block of destination 8bit values that
1073 // we have computed, fit in the destination matrix. Typically, all of
1074 // it fits, but when the destination matrix shape is not a multiple
1075 // of 4x2, there are some 4x2 blocks along the boundaries that do
1076 // not fit entirely.
1077
1078 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1079 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1080 "sub r1, r1, r8\n"
1081
1082 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1083 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1084 "sub r2, r2, r4\n"
1085 "mov r3, #4\n"
1086 "mov r5, #2\n"
1087 "cmp r1, #4\n"
1088 // Compute r1 = how many rows of the 4x2 block fit
1089 "it gt\n"
1090 "movgt r1, r3\n"
1091
1092 "cmp r2, #2\n"
1093 // Compute r2 = how many cols of the 4x2 block fit
1094 "it gt\n"
1095 "movgt r2, r5\n"
1096
1097 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1098 "cmp r1, r3\n"
1099 "it eq\n"
1100 "cmpeq r2, r5\n"
1101 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1102 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1103 // Yes, all of the 4x2 block fits, go to fast path.
1104 "beq 30f\n"
1105 // Not all of the 4x2 block fits.
1106 // Store to dst_tmp_buf
1107 // Set r3 address to write to dst_tmp_buf.
1108 "mov r3, %[dst_tmp_buf]\n"
1109 "vst1.8 {d30}, [r3]\n"
1110
1111 // Slow loop copying from dst_tmp_buf to dst.
1112 "mov r6, #0\n"
1113 "50:\n"
1114 "mov r8, #0\n"
1115 "51:\n"
1116 "ldrb r10, [r3, r8]\n"
1117 "strb r10, [r4, r8]\n"
1118 "add r8, r8, #1\n"
1119 "cmp r8, r1\n"
1120 "blt 51b\n"
1121 "add r6, r6, #1\n"
1122 "add r3, r3, #4\n"
1123 "add r4, r4, r5\n"
1124 "cmp r6, r2\n"
1125 "blt 50b\n"
1126 "b 31f\n"
1127 "30:\n"
1128 // Yes, all of the 4x2 block fits.
1129 // r3 address, r5 stride
1130 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1131 "mov r4, r3\n"
1132 "mov r6, #1\n"
1133
1134 "vst1.32 {d30[0]}, [r3]\n"
1135 "add r4, r4, r5\n"
1136 "mov r3, r4\n"
1137 "vst1.32 {d30[1]}, [r3]\n"
1138
1139 "31:\n"
1140
1141 // Load dst_ptr, increment, and write back.
1142 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1143 "add r4, r4, #4\n"
1144 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1145
1146 RUY_MAKE_ZERO(q13)
1147 RUY_MAKE_ZERO(q14)
1148 RUY_MAKE_ZERO(q15)
1149
1150 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1151
1152 // Store int8 values:
1153 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
1154
1155 // Cast-and-saturate from int32 to int16
1156 // After this, all values for output are in q14.
1157 "vqmovn.s32 d28, q14\n"
1158 "vqmovn.s32 d29, q15\n"
1159
1160 // At this point, d12 -- d26, d30, d31 aren't used anymore for the
1161 // current block, so we can start clearing these accumulators for the
1162 // next block (next iteration of the main loop).
1163 RUY_MAKE_ZERO(q6)
1164 RUY_MAKE_ZERO(q7)
1165 RUY_MAKE_ZERO(q8)
1166 RUY_MAKE_ZERO(q9)
1167 RUY_MAKE_ZERO(q10)
1168 RUY_MAKE_ZERO(q11)
1169 RUY_MAKE_ZERO(q12)
1170 RUY_MAKE_ZERO(q13)
1171 RUY_MAKE_ZERO(q15)
1172
1173 // Load the destination zero point into each of the 8 16-bit slots
1174 // in a q register.
1175 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1176 "vdup.16 q13, r4\n" // dst_zero_point
1177
1178 // Add the destination zero point
1179 "vadd.i16 q14, q14, q13\n"
1180
1181 // Cast-and-saturate from int16 to int8
1182 // Now all 8 1-byte values are in d30.
1183 "vqmovn.s16 d30, q14\n"
1184
1185 // Load the clamp_min, clamp_max bounds
1186 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1187 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1188 "vdup.8 d28, r2\n" // clamp_min
1189 "vdup.8 d29, r3\n" // clamp_max
1190
1191 // Apply the clamp_min bound
1192 "vmax.s8 d30, d30, d28\n"
1193 // Apply the clamp_max bound
1194 "vmin.s8 d30, d30, d29\n"
1195
1196 // Compute how much of the 4x2 block of destination 8bit values that
1197 // we have computed, fit in the destination matrix. Typically, all of
1198 // it fits, but when the destination matrix shape is not a multiple
1199 // of 4x2, there are some 4x2 blocks along the boundaries that do
1200 // not fit entirely.
1201
1202 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1203 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1204 "sub r1, r1, r8\n"
1205
1206 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1207 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1208 "sub r2, r2, r4\n"
1209 "mov r3, #4\n"
1210 "mov r5, #2\n"
1211 "cmp r1, #4\n"
1212 // Compute r1 = how many rows of the 4x2 block fit
1213 "it gt\n"
1214 "movgt r1, r3\n"
1215
1216 "cmp r2, #2\n"
1217 // Compute r2 = how many cols of the 4x2 block fit
1218 "it gt\n"
1219 "movgt r2, r5\n"
1220
1221 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1222 "cmp r1, r3\n"
1223 "it eq\n"
1224 "cmpeq r2, r5\n"
1225 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1226 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1227 // Yes, all of the 4x2 block fits, go to fast path.
1228 "beq 30f\n"
1229 // Not all of the 4x2 block fits.
1230 // Store to dst_tmp_buf
1231 // Set r3 address to write to dst_tmp_buf.
1232 "mov r3, %[dst_tmp_buf]\n"
1233 "vst1.8 {d30}, [r3]\n"
1234
1235 // Slow loop copying from dst_tmp_buf to dst.
1236 "mov r6, #0\n"
1237 "50:\n"
1238 "mov r8, #0\n"
1239 "51:\n"
1240 "ldrb r10, [r3, r8]\n"
1241 "strb r10, [r4, r8]\n"
1242 "add r8, r8, #1\n"
1243 "cmp r8, r1\n"
1244 "blt 51b\n"
1245 "add r6, r6, #1\n"
1246 "add r3, r3, #4\n"
1247 "add r4, r4, r5\n"
1248 "cmp r6, r2\n"
1249 "blt 50b\n"
1250 "b 31f\n"
1251 "30:\n"
1252 // Yes, all of the 4x2 block fits.
1253 // r3 address, r5 stride
1254 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1255 "mov r4, r3\n"
1256 "mov r6, #1\n"
1257
1258 "vst1.32 {d30[0]}, [r3]\n"
1259 "add r4, r4, r5\n"
1260 "mov r3, r4\n"
1261 "vst1.32 {d30[1]}, [r3]\n"
1262
1263 "31:\n"
1264
1265 // Load dst_ptr, increment, and write back.
1266 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1267 "add r4, r4, #4\n"
1268 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1269
1270 RUY_MAKE_ZERO(q13)
1271 RUY_MAKE_ZERO(q14)
1272 RUY_MAKE_ZERO(q15)
1273
1274 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1275
1276 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
1277
1278 // Load the destination zero point into each of the 4 32-bit slots
1279 // in a q register.
1280 "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1281 "vdup.32 q13, r4\n" // dst_zero_point
1282 // Add the destination zero point
1283 "vadd.s32 q14, q14, q13\n"
1284 "vadd.s32 q15, q15, q13\n"
1285
1286 // Cast-and-saturate from int32 to int16
1287 // After this, all values for output are in q14.
1288 "vqmovn.s32 d28, q14\n"
1289 "vqmovn.s32 d29, q15\n"
1290
1291 // At this point, v18 -- v31 aren't used anymore for the current block,
1292 // so we can start clearing these accumulators for the next block
1293 // (next iteration of the main loop).
1294 RUY_MAKE_ZERO(q6)
1295 RUY_MAKE_ZERO(q7)
1296 RUY_MAKE_ZERO(q8)
1297 RUY_MAKE_ZERO(q9)
1298 RUY_MAKE_ZERO(q10)
1299 RUY_MAKE_ZERO(q11)
1300 RUY_MAKE_ZERO(q15)
1301
1302 // Load the clamp_min, clamp_max bounds
1303 "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1304 "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1305 "vdup.16 q12, r2\n" // clamp_min
1306 "vdup.16 q13, r3\n" // clamp_max
1307
1308 // Apply the clamp_min bound
1309 "vmax.s16 q14, q14, q12\n"
1310 // Apply the clamp_max bound
1311 "vmin.s16 q14, q14, q13\n"
1312
1313 RUY_MAKE_ZERO(q12)
1314 RUY_MAKE_ZERO(q13)
1315
1316 // Compute how much of the 4x2 block of destination 16-bit values that
1317 // we have computed, fit in the destination matrix. Typically, all of
1318 // it fits, but when the destination matrix shape is not a multiple
1319 // of 4x2, there are some 4x2 blocks along the boundaries that do
1320 // not fit entirely.
1321
1322 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1323 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1324 "sub r1, r1, r8\n"
1325
1326 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1327 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1328 "sub r2, r2, r4\n"
1329 "mov r3, #4\n"
1330 "mov r5, #2\n"
1331 "cmp r1, #4\n"
1332 // Compute r1 = how many rows of the 4x2 block fit
1333 "it gt\n"
1334 "movgt r1, r3\n"
1335
1336 "cmp r2, #2\n"
1337 // Compute r2 = how many cols of the 4x2 block fit
1338 "it gt\n"
1339 "movgt r2, r5\n"
1340
1341 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1342 "cmp r1, r3\n"
1343 "it eq\n"
1344 "cmpeq r2, r5\n"
1345 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1346 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1347 // Yes, all of the 4x2 block fits, go to fast path.
1348 "beq 30f\n"
1349 // Not all of the 4x2 block fits.
1350 // Store to dst_tmp_buf
1351 // Set r3 address to write to dst_tmp_buf.
1352 "mov r3, %[dst_tmp_buf]\n"
1353 "vst1.16 {q14}, [r3]\n"
1354
1355 // Slow loop copying from dst_tmp_buf to dst.
1356 "mov r6, #0\n"
1357 "50:\n"
1358 "mov r8, #0\n"
1359 "51:\n"
1360 // Shift of offset register for half-word loads not allowed in A32,
1361 // so we shift, load/store, then shift back r8.
1362 "lsl r8, r8, #1\n"
1363 "ldrh r10, [r3, r8]\n"
1364 "strh r10, [r4, r8]\n"
1365 "lsr r8, r8, #1\n"
1366 "add r8, r8, #1\n"
1367 "cmp r8, r1\n"
1368 "blt 51b\n"
1369 "add r6, r6, #1\n"
1370 "add r3, r3, #8\n"
1371 "add r4, r4, r5\n"
1372 "cmp r6, r2\n"
1373 "blt 50b\n"
1374 "b 31f\n"
1375 "30:\n"
1376 // Yes, all of the 4x2 block fits.
1377 // r3 address, r5 stride
1378 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1379 "mov r4, r3\n"
1380 "mov r6, #2\n"
1381
1382 "vst1.16 {d28[0]}, [r3], r6\n"
1383 "add r4, r4, r5\n"
1384 "vst1.16 {d28[1]}, [r3], r6\n"
1385 "vst1.16 {d28[2]}, [r3], r6\n"
1386 "vst1.16 {d28[3]}, [r3], r6\n"
1387 "mov r3, r4\n"
1388 "vst1.16 {d29[0]}, [r3], r6\n"
1389 "vst1.16 {d29[1]}, [r3], r6\n"
1390 "vst1.16 {d29[2]}, [r3], r6\n"
1391 "vst1.16 {d29[3]}, [r3], r6\n"
1392 "31:\n"
1393
1394 // Load dst_ptr, increment, and write back.
1395 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1396 "add r4, r4, #8\n"
1397 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1398
1399 RUY_MAKE_ZERO(q14)
1400
1401 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1402
1403 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
1404
1405 // Since the store type is the same as the accum type, no need for
1406 // downcast. There's also no need for clamp by min/max.
1407
1408 // At this point, v20 -- v31 aren't used anymore for the current block,
1409 // so we can start clearing these accumulators for the next block
1410 // (next iteration of the main loop).
1411 // Clear accumulators.
1412 RUY_MAKE_ZERO(q6)
1413 RUY_MAKE_ZERO(q7)
1414 RUY_MAKE_ZERO(q8)
1415 RUY_MAKE_ZERO(q9)
1416 RUY_MAKE_ZERO(q10)
1417 RUY_MAKE_ZERO(q11)
1418 RUY_MAKE_ZERO(q12)
1419 RUY_MAKE_ZERO(q13)
1420
1421 // Compute how much of the 4x2 block of destination 32 bit values that
1422 // we have computed, fit in the destination matrix. Typically, all of
1423 // it fits, but when the destination matrix shape is not a multiple
1424 // of 4x2, there are some 4x4 blocks along the boundaries that do
1425 // not fit entirely.
1426
1427 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1428 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1429 "sub r1, r1, r8\n"
1430
1431 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1432 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1433 "sub r2, r2, r4\n"
1434 "mov r3, #4\n"
1435 "mov r5, #2\n"
1436 "cmp r1, #4\n"
1437 // Compute r1 = how many rows of the 4x2 block fit
1438 "it gt\n"
1439 "movgt r1, r3\n"
1440
1441 "cmp r2, #2\n"
1442 // Compute r2 = how many cols of the 4x2 block fit
1443 "it gt\n"
1444 "movgt r2, r5\n"
1445
1446 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1447 "cmp r1, r3\n"
1448 "it eq\n"
1449 "cmpeq r2, r5\n"
1450 // Yes, all of the 4x2 block fits, go to fast path.
1451 "beq 30f\n"
1452 // Not all of the 4x2 block fits.
1453 // Set (r3 address, r4 stride) to write to dst_tmp_buf
1454 "mov r3, %[dst_tmp_buf]\n"
1455 "mov r4, #16\n"
1456 "b 31f\n"
1457
1458 "30:\n"
1459 // Yes, all of the 4x2 block fits.
1460 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1461 // r3 address, r4 stride
1462 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1463 "mov r4, r5\n"
1464
1465 "31:\n"
1466
1467 "vst1.32 {d28, d29}, [r3]\n"
1468 "add r3, r3, r4\n"
1469 "vst1.32 {d30, d31}, [r3]\n"
1470
1471 // If all of the 4x2 block fits, we just finished writing it to the
1472 // destination, so we skip the next part.
1473 "beq 41f\n"
1474 // Not all of the 4x2 block fits in the destination matrix. We just
1475 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
1476 // it to copy into the destination matrix the part that fits.
1477 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1478 "mov r3, %[dst_tmp_buf]\n"
1479 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1480 "mov r6, #0\n"
1481 "50:\n"
1482 "mov r5, #0\n"
1483 "51:\n"
1484 "ldr r10, [r3, r5, lsl #2]\n"
1485 "str r10, [r4, r5, lsl #2]\n"
1486 "add r5, r5, #1\n"
1487 "cmp r5, r1\n"
1488 "blt 51b\n"
1489 "add r6, r6, #1\n"
1490 "add r3, r3, #16\n"
1491 "add r4, r4, r8\n"
1492 // r2 = how many cols of the 8x4 block fit
1493 "cmp r6, r2\n"
1494 "blt 50b\n"
1495
1496 "41:\n"
1497 // Load dst_ptr, increment, and write back.
1498 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1499 "add r4, r4, #16\n"
1500 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1501
1502 RUY_MAKE_ZERO(q10)
1503 RUY_MAKE_ZERO(q11)
1504
1505 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1506
1507 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1508
1509 // Reload some params --- we had used x5 -- x7 for a few other things
1510 // since the last time we had loaded them.
1511 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1512 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1513 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1514
1515 // Move to the next block of the destination matrix, for the next iter
1516 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
1517 // been updated earlier.
1518 // Have we reached the end row?
1519 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1520 "cmp r8, r3\n"
1521
1522 "beq 20f\n" // yes, end row.
1523 // Not end row. Move to the next row.
1524 "add r8, r8, #4\n"
1525 // Store new value of row
1526 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1527
1528 "b 21f\n"
1529 "20:\n"
1530 // Was already at end row.
1531 // Move back to first row.
1532 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1533 // Move to the next column.
1534 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1535 "add r4, r4, #2\n"
1536 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1537
1538 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1539 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1540 // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
1541 "add r1, r1, r8, lsl #1\n"
1542 // Store dst_col_ptr
1543 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1544 // Store dst_ptr
1545 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1546 "21:\n"
1547
1548 // Main loop exit condition: have we hit the end column?
1549 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1550 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1551 "cmp r8, r4\n"
1552
1553 // w1 is the number of levels of depth that we have already loaded
1554 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1555 // above, this is currently 16.
1556 "mov r1, #16\n"
1557
1558 "ble 1b\n"
1559
1560 // Restore stack pointer.
1561 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1562
1563 // clang-format on
1564
1565 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
1566 : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf)
1567 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
1568 // Clobber list must specify q registers (and not their constituent
1569 // d registers). There is a (currently unexplained) slowdown if
1570 // d registers are listed in the clobbers list.
1571 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
1572 "q9", "q10", "q12", "q13", "q14", "q15");
1573 }
1574
1575 // Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
1576 // is still packed as if it has two columns
Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4,2> & params)1577 void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) {
1578 profiler::ScopeLabel label(
1579 "Kernel (kNeon, optimized for out-of-order cores)");
1580
1581 CheckOffsetsInKernelParams8bit(params);
1582
1583 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1584 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
1585 const std::int8_t* lhs_ptr = lhs_col_ptr;
1586 const std::int8_t* rhs_ptr = rhs_col_ptr;
1587
1588 // The asm kernel below has the following NEON register allocation:
1589 //
1590 // q6 - q13 are 128-bit (4x32b) accumulators.
1591 // During accumulation, d0 -- d7 are used to load int8 data from LHS and
1592 // d8 -- d11 from RHS:
1593 // int8 RHS 16x1 block
1594 // /------------\
1595 // | d8.b[0] |
1596 // | ... |
1597 // | d8.b[7] |
1598 // | d9.b[0] |
1599 // | ... |
1600 // | d9.b[7] |
1601 // \------------/
1602 // int8 LHS 4x16 block
1603 // /-----------------------------------------\ /------------\
1604 // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
1605 // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
1606 // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
1607 // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
1608 // \-----------------------------------------/ \------------/
1609 // 128-bit accumulators 4x1 block
1610 //
1611 // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
1612 // optimization for this kernel.
1613 asm volatile(
1614 #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
1615
1616 // clang-format off
1617
1618 // Load the first 64 bytes of LHS and RHS data.
1619 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1620 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1621 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1622 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1623 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1624 // Skip the other column and advance the pointer.
1625 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1626
1627 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1628
1629 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1630 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1631
1632 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1633 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1634
1635 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1636 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1637
1638 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
1639 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1640
1641 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1642 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1643
1644 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
1645 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1646
1647 // Clear accumulators.
1648 RUY_MAKE_ZERO(q6)
1649 RUY_MAKE_ZERO(q7)
1650 RUY_MAKE_ZERO(q8)
1651 RUY_MAKE_ZERO(q9)
1652 RUY_MAKE_ZERO(q10)
1653 RUY_MAKE_ZERO(q11)
1654 RUY_MAKE_ZERO(q12)
1655 RUY_MAKE_ZERO(q13)
1656 RUY_MAKE_ZERO(q14)
1657 RUY_MAKE_ZERO(q15)
1658
1659 // r1 is the number of levels of depth that we have already loaded
1660 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1661 // above, this is currently 16.
1662 "mov r1, #16\n"
1663
1664 // Main loop of the whole GEMM, over rows and columns of the
1665 // destination matrix.
1666 "1:\n"
1667
1668 // r1 is how many levels of depth we have already loaded
1669 // data for, r10 is the total depth.
1670 "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
1671 "cmp r1, r10\n"
1672 "beq 79f\n"
1673
1674 "2:\n"
1675
1676 // Mult, mult-acc in to q14, q15
1677 "vmull.s8 q14, d0, d8\n"
1678 "vmull.s8 q15, d2, d8\n"
1679 "vmlal.s8 q14, d1, d9\n"
1680 "vmlal.s8 q15, d3, d9\n"
1681
1682 // Then pairwise accumulate in to q6, q7
1683 "vpadal.s16 q6, q14\n"
1684 "vpadal.s16 q7, q15\n"
1685
1686 // Mult, mult-acc in to q14, q15
1687 "vmull.s8 q14, d4, d8\n"
1688 "vmull.s8 q15, d6, d8\n"
1689 "vmlal.s8 q14, d5, d9\n"
1690 "vmlal.s8 q15, d7, d9\n"
1691
1692 // Then pairwise accumulate in to q8, q9
1693 "vpadal.s16 q8, q14\n"
1694 "vpadal.s16 q9, q15\n"
1695
1696
1697 // Load the next 64 bytes of LHS and RHS data.
1698 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1699 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1700 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1701 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1702 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1703 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1704 // Skip the other column and advance the pointer.
1705 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1706 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1707
1708 // Each iteration of this loop advances by 16 levels of depth.
1709 "add r1, r1, #16\n"
1710
1711 // Loop termination condition
1712 "cmp r1, r10\n"
1713
1714 "blt 2b\n"
1715
1716 "79:\n"
1717
1718 // Mult, mult-acc in to q14, q15
1719 "vmull.s8 q14, d0, d8\n"
1720 "vmull.s8 q15, d2, d8\n"
1721 "vmlal.s8 q14, d1, d9\n"
1722 "vmlal.s8 q15, d3, d9\n"
1723
1724 // Then pairwise accumulate in to q6, q7
1725 "vpadal.s16 q6, q14\n"
1726 "vpadal.s16 q7, q15\n"
1727
1728 // Mult, mult-acc in to q14, q15
1729 "vmull.s8 q14, d4, d8\n"
1730 "vmull.s8 q15, d6, d8\n"
1731 "vmlal.s8 q14, d5, d9\n"
1732 "vmlal.s8 q15, d7, d9\n"
1733
1734 // Then pairwise accumulate in to q8, q9
1735 "vpadal.s16 q8, q14\n"
1736 "vpadal.s16 q9, q15\n"
1737
1738 // All accumulation over depth done. q6 - q9 contain the 4x32b
1739 // accumulators for the 4x1 final matrix.
1740 // We now have to compute the final 8-bit values from these int32
1741 // accumulators, and advance to the next 4x2 block. We intertwine
1742 // these two aspects whenever possible for optimal pipelining, both
1743 // at the data flow level (prefetch data for next block as early as
1744 // possible) and instruction pipelining level (some of the next-block
1745 // work can dual-issue with some of the final work on the current
1746 // block).
1747
1748 // q6-q9 now contain 4 x 32b
1749 "vpadd.i32 d0, d12, d13\n"
1750 "vpadd.i32 d1, d14, d15\n"
1751 "vpadd.i32 d2, d16, d17\n"
1752 "vpadd.i32 d3, d18, d19\n"
1753
1754 // d0-d4 each contain 2 x 32b accumulators.
1755 // Need to add pairwise to get 1 x 32b for each of the 4x1 entries
1756 // of destination, (Four 'd' registers total)
1757 "vpadd.i32 d28, d0, d1\n"
1758 "vpadd.i32 d29, d2, d3\n"
1759
1760 // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
1761
1762 // Logic to advance to the next block in preparation for the next
1763 // iteration of the main loop. For now, we only want to compute
1764 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
1765 // not yet ready to update the values of row and col, as we still need
1766 // the current values for the rest of the work on the current block.
1767
1768 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1769 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1770 "cmp r1, r3\n" // Have we finished the last row?
1771
1772 "bge 4f\n" // If finished last row, go to 4
1773 // Not finished last row: then advance to next row.
1774 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
1775 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1776 "add r4, r4, r1, lsl #2\n"
1777 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1778 "b 5f\n"
1779 "4:\n" // Finished last row...
1780 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1781 // Go back to first row
1782 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1783
1784 // Now we need to advance to the next column. If we already
1785 // finished the last column, then in principle we are done, however
1786 // we can't just return here, as we need to allow the end work of the
1787 // current block to complete. The good news is that at this point it
1788 // doesn't matter what data we load for the next column, since
1789 // we will exit from the main loop below before actually storing
1790 // anything computed from that data.
1791
1792 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1793 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1794 "cmp r8, r4\n" // Have we finished the last column?
1795 "bge 5f\n" // If yes, just carry on without updating the column pointer.
1796 // Not finished last column: then advance to next column.
1797 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
1798 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1799 "add r10, r10, r1, lsl #1\n"
1800 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1801 "5:\n"
1802
1803 // Set the LHS and RHS data pointers to the start of the columns just
1804 // computed.
1805 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1806 "mov %[lhs_ptr], r4\n"
1807 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1808 "mov %[rhs_ptr], r5\n"
1809
1810 // Now we load: bias data, LHS sums data, RHS sums data.
1811
1812 // First, load the base pointers from the params.
1813 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1814 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
1815
1816 // Offset these base pointers as needed given the current row, col.
1817 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1818 "add r5, r1, r8, lsl #2\n"
1819
1820 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
1821 "it ne\n"
1822 "movne r1, r5\n"
1823
1824 // Load 4 bias values.
1825 "vld1.32 {d24, d25}, [r1]\n"
1826
1827 // Now that we know what LHS and RHS data the next iteration of the
1828 // main loop will need to load, we start loading the first 32 bytes of
1829 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
1830 // in the rest of the work on the current block.
1831 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1832 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1833 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1834 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1835 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1836 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1837 // Skip the other column and advance the pointer.
1838 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1839 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1840
1841 // Add to the bias values the product
1842 // (depth * lhs_zero_point * rhs_zero_point),
1843 // See the term NZ1Z2 in equation (7) in
1844 // https://arxiv.org/pdf/1712.05877.pdf
1845 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
1846 "vdup.32 q9, r3\n"
1847 "vadd.i32 q12, q12, q9\n"
1848
1849 // Perform the bias-addition (per the above, we have just folded into
1850 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
1851 "vadd.i32 q14, q14, q12\n"
1852
1853 // LHS/RHS zero points
1854 // Has RHS sums
1855 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1856 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
1857 "beq 401f\n"
1858 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
1859 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1860 // Offset by current col * number of bytes per value
1861 "add r3, r3, r4, lsl #2\n"
1862 "vld1.32 { d12 }, [r3]\n"
1863 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
1864 "vdup.32 q10, r5\n" // create lhs_zero_point_vec
1865 // Subtract rhs_sums * lhs_zero_point, per
1866 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1867 "vmls.i32 q14, q10, d12[0]\n"
1868 "401:\n"
1869
1870 // Has LHS sums
1871 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1872 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
1873 "beq 402f\n"
1874 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
1875 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1876 // Offset by current row * number of bytes per value
1877 "add r2, r2, r4, lsl #2\n"
1878 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
1879
1880 // Load 4 lhs_sums values.
1881 "vld1.32 {d22, d23}, [r2]\n"
1882 "vdup.32 d13, r5\n" // rhs_zero_point
1883
1884 // Compute lhs_sums * rhs_zero_point.
1885 "vmul.i32 q11, q11, d13[1]\n"
1886 // Subtract lhs_sums * rhs_zero_point, per
1887 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1888 "vsub.s32 q14, q14, q11\n"
1889
1890 // If the destination is int32, it means the user asks for the raw
1891 // accumulators, no need for us to downquantize the value.
1892 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1893 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1894 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1895
1896 "402:\n"
1897
1898 // At this point we have computed the final int32 values. Now we
1899 // start down-quantizing them to obtain the final 8bit values from them.
1900
1901 // As part of this down-quantization, our int32 values will be
1902 // multiplied by a multiplier that has a fixed-point component and an
1903 // exponent component.
1904
1905 //Load the exponent part of the multiplier.
1906 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1907 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1908 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1909 "add r5, r1, r4, lsl #2\n"
1910 "it ne\n"
1911 "movne r1, r5\n"
1912
1913 "vld1.32 {q10}, [r1]\n"
1914
1915 RUY_MAKE_ZERO(q8)
1916 "vmax.s32 q12, q10, q8\n"
1917
1918 "vshl.s32 q14, q14, q12\n"
1919
1920 "vmin.s32 q12, q10, q8\n"
1921
1922 // Load fixed point part of the multiplier
1923 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1924 // r6 has flags, r4 has row
1925 "add r5, r1, r4, lsl #2\n"
1926 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1927 "it ne\n"
1928 "movne r1, r5\n"
1929 "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
1930
1931 // Apply the fixed-point part of the multiplier.
1932 "vqrdmulh.s32 q14, q14, q10\n"
1933
1934 // We have some rounding division-by-power-of-two to do. This should
1935 // always use "round to nearest". We allow for some
1936 // freedom in how ties are broken, to strike a good compromise of
1937 // performance on given hardware vs. perfect agreement of results
1938 // across hardware.
1939 //
1940 // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
1941 // defined tie-breaks to help performance. On NEON, this means that we
1942 // can just use the NEON rounding instructions, such as srshl. They
1943 // happen to be breaking ties upward.
1944 //
1945 // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
1946 // break-ties-away-from zero, as described in Appendix B of
1947 // https://arxiv.org/pdf/1712.05877.pdf
1948 // When we wrote that, we thought that that would be better unbiased
1949 // than the NEON upwards tie-breaks, and we had observed some
1950 // improvement on some model. However, that is only more unbiased for
1951 // data centered at zero, which was likely the case in that model,
1952 // but is not always the case. If we wanted something more consistently
1953 // unbiased then we should try breaking ties toward-nearest-even.
1954 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
1955 // Fix up values to be right-shifted, so that the (round to nearest,
1956 // break ties upward) behavior of srshl applied to these fixed-up
1957 // values, produces the same result as the desired (round to nearest,
1958 // break ties away from zero) behavior on the original values.
1959 "vand q8, q14, q12\n"
1960 "vshr.s32 q8, q8, #31\n"
1961 "vqadd.s32 q14, q14, q8\n"
1962
1963 #endif
1964 // At this point we have reduced the problem of correctly implementing
1965 // rounding divide-by-power-of-two, to what the SRSHL instruction can
1966 // do.
1967 "vrshl.s32 q14, q14, q12\n"
1968
1969 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1970 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1971 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1972 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1973 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1974
1975 // Store uint8 values:
1976 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1977
1978 // Cast-and-saturate from int32 to int16
1979 // After this, all values for output are in d28.
1980 "vqmovn.s32 d28, q14\n"
1981
1982 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
1983 // current block, so we can start clearing these accumulators for the
1984 // next block (next iteration of the main loop).
1985 RUY_MAKE_ZERO(q6)
1986 RUY_MAKE_ZERO(q7)
1987 RUY_MAKE_ZERO(q8)
1988 RUY_MAKE_ZERO(q9)
1989 RUY_MAKE_ZERO(q10)
1990 RUY_MAKE_ZERO(q11)
1991 RUY_MAKE_ZERO(q12)
1992 RUY_MAKE_ZERO(q13)
1993 RUY_MAKE_ZERO(q15)
1994
1995 // Load the destination zero point into each of the 8 16-bit slots
1996 // in a q register.
1997 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1998 "vdup.16 q13, r4\n" // dst_zero_point
1999
2000 // Add the destination zero point
2001 "vadd.i16 q14, q14, q13\n"
2002
2003 // Cast-and-saturate from int16 to uint8
2004 "vqmovun.s16 d30, q14\n"
2005 // At this point, we only need 4 8-bit values in the lower half
2006 // of d30.
2007
2008
2009 // Load the clamp_min, clamp_max bounds
2010 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2011 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2012 "vdup.8 d28, r2\n" // clamp_min
2013 "vdup.8 d29, r3\n" // clamp_max
2014
2015 // Apply the clamp_min bound
2016 "vmax.u8 d30, d30, d28\n"
2017 // Apply the clamp_max bound
2018 "vmin.u8 d30, d30, d29\n"
2019
2020 // Compute how much of the 4x1 block of destination 8bit values that
2021 // we have computed, fit in the destination matrix. Typically, all of
2022 // it fits, but when the destination matrix shape is not a multiple
2023 // of 4x1, there are some 4x1 blocks along the boundaries that do
2024 // not fit entirely.
2025
2026 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2027 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2028 "sub r1, r1, r8\n"
2029
2030 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2031 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2032 "sub r2, r2, r4\n"
2033 "mov r3, #4\n"
2034 "mov r5, #2\n"
2035 "cmp r1, #4\n"
2036 // Compute r1 = how many rows of the 4x1 block fit
2037 "it gt\n"
2038 "movgt r1, r3\n"
2039
2040 // Test if r1==4, i.e. if all of the 4x1 block fits.
2041 "cmp r1, r3\n"
2042
2043 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2044 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2045 // Yes, all of the 4x1 block fits, go to fast path.
2046 "beq 30f\n"
2047 // Not all of the 4x1 block fits.
2048 // Store to dst_tmp_buf
2049 // Set r3 address to write to dst_tmp_buf.
2050 "mov r3, %[dst_tmp_buf]\n"
2051 "vst1.8 {d30}, [r3]\n"
2052
2053 // Slow loop copying from dst_tmp_buf to dst.
2054 "50:\n"
2055 "mov r8, #0\n"
2056 "51:\n"
2057 "ldrb r10, [r3, r8]\n"
2058 "strb r10, [r4, r8]\n"
2059 "add r8, r8, #1\n"
2060 "cmp r8, r1\n"
2061 "blt 51b\n"
2062 "b 31f\n"
2063 "30:\n"
2064 // Yes, all of the 4x1 block fits.
2065 // r3 address, r5 stride
2066 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2067 "mov r4, r3\n"
2068 "mov r6, #1\n"
2069
2070 "vst1.8 {d30[0]}, [r3], r6\n"
2071 "vst1.8 {d30[1]}, [r3], r6\n"
2072 "vst1.8 {d30[2]}, [r3], r6\n"
2073 "vst1.8 {d30[3]}, [r3], r6\n"
2074 "31:\n"
2075
2076 // Load dst_ptr, increment, and write back.
2077 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2078 "add r4, r4, #4\n"
2079 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2080
2081 RUY_MAKE_ZERO(q13)
2082 RUY_MAKE_ZERO(q14)
2083 RUY_MAKE_ZERO(q15)
2084
2085 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2086
2087 // Store int8 values:
2088 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
2089
2090 // Cast-and-saturate from int32 to int16
2091 // After this, all values for output are in d28.
2092 "vqmovn.s32 d28, q14\n"
2093
2094 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2095 // current block, so we can start clearing these accumulators for the
2096 // next block (next iteration of the main loop).
2097 RUY_MAKE_ZERO(q6)
2098 RUY_MAKE_ZERO(q7)
2099 RUY_MAKE_ZERO(q8)
2100 RUY_MAKE_ZERO(q9)
2101 RUY_MAKE_ZERO(q10)
2102 RUY_MAKE_ZERO(q11)
2103 RUY_MAKE_ZERO(q12)
2104 RUY_MAKE_ZERO(q13)
2105 RUY_MAKE_ZERO(q15)
2106
2107 // Load the destination zero point into each of the 8 16-bit slots
2108 // in a q register.
2109 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2110 "vdup.16 q13, r4\n" // dst_zero_point
2111
2112 // Add the destination zero point
2113 "vadd.i16 q14, q14, q13\n"
2114
2115 // Cast-and-saturate from int16 to int8
2116 "vqmovn.s16 d30, q14\n"
2117 // At this point, we only need 4 8-bit values in the lower half
2118 // of d30.
2119
2120 // Load the clamp_min, clamp_max bounds
2121 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2122 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2123 "vdup.8 d28, r2\n" // clamp_min
2124 "vdup.8 d29, r3\n" // clamp_max
2125
2126 // Apply the clamp_min bound
2127 "vmax.s8 d30, d30, d28\n"
2128 // Apply the clamp_max bound
2129 "vmin.s8 d30, d30, d29\n"
2130
2131 // Compute how much of the 4x1 block of destination 8bit values that
2132 // we have computed, fit in the destination matrix. Typically, all of
2133 // it fits, but when the destination matrix shape is not a multiple
2134 // of 4x2, there are some 4x2 blocks along the boundaries that do
2135 // not fit entirely.
2136
2137 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2138 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2139 "sub r1, r1, r8\n"
2140
2141 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2142 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2143 "sub r2, r2, r4\n"
2144 "mov r3, #4\n"
2145 "mov r5, #2\n"
2146 "cmp r1, #4\n"
2147 // Compute r1 = how many rows of the 4x2 block fit
2148 "it gt\n"
2149 "movgt r1, r3\n"
2150
2151 // Test if r1==4 i.e. if all of the 4x1 block fits.
2152 "cmp r1, r3\n"
2153
2154 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2155 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2156 // Yes, all of the 4x2 block fits, go to fast path.
2157 "beq 30f\n"
2158 // Not all of the 4x2 block fits.
2159 // Store to dst_tmp_buf
2160 // Set r3 address to write to dst_tmp_buf.
2161 "mov r3, %[dst_tmp_buf]\n"
2162 "vst1.8 {d30}, [r3]\n"
2163
2164 // Slow loop copying from dst_tmp_buf to dst.
2165 "50:\n"
2166 "mov r8, #0\n"
2167 "51:\n"
2168 "ldrb r10, [r3, r8]\n"
2169 "strb r10, [r4, r8]\n"
2170 "add r8, r8, #1\n"
2171 "cmp r8, r1\n"
2172 "blt 51b\n"
2173 "b 31f\n"
2174 "30:\n"
2175 // Yes, all of the 4x1 block fits.
2176 // r3 address, r5 stride
2177 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2178 "mov r4, r3\n"
2179 "mov r6, #1\n"
2180
2181 "vst1.8 {d30[0]}, [r3], r6\n"
2182 "vst1.8 {d30[1]}, [r3], r6\n"
2183 "vst1.8 {d30[2]}, [r3], r6\n"
2184 "vst1.8 {d30[3]}, [r3], r6\n"
2185 "31:\n"
2186
2187 // Load dst_ptr, increment, and write back.
2188 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2189 "add r4, r4, #4\n"
2190 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2191
2192 RUY_MAKE_ZERO(q13)
2193 RUY_MAKE_ZERO(q14)
2194 RUY_MAKE_ZERO(q15)
2195
2196 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2197
2198 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
2199
2200 // Load the destination zero point into each of the 4 32-bit slots
2201 // in a q register.
2202 "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2203 "vdup.32 q13, r4\n" // dst_zero_point
2204 // Add the destination zero point
2205 "vadd.s32 q14, q14, q13\n"
2206 //"vadd.s32 q15, q15, q13\n"
2207
2208 // Cast-and-saturate from int32 to int16
2209 // After this, all values for output are in d28.
2210 "vqmovn.s32 d28, q14\n"
2211
2212 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2213 // so we can start clearing these accumulators for the next block
2214 // (next iteration of the main loop).
2215 RUY_MAKE_ZERO(q6)
2216 RUY_MAKE_ZERO(q7)
2217 RUY_MAKE_ZERO(q8)
2218 RUY_MAKE_ZERO(q9)
2219 RUY_MAKE_ZERO(q10)
2220 RUY_MAKE_ZERO(q11)
2221 RUY_MAKE_ZERO(q15)
2222
2223 // Load the clamp_min, clamp_max bounds
2224 "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2225 "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2226 "vdup.16 d24, r2\n" // clamp_min
2227 "vdup.16 d26, r3\n" // clamp_max
2228
2229 // Apply the clamp_min bound
2230 "vmax.s16 d28, d28, d24\n"
2231 // Apply the clamp_max bound
2232 "vmin.s16 d28, d28, d26\n"
2233
2234 RUY_MAKE_ZERO(q12)
2235 RUY_MAKE_ZERO(q13)
2236
2237 // Compute how much of the 4x1 block of destination 16-bit values that
2238 // we have computed, fit in the destination matrix. Typically, all of
2239 // it fits, but when the destination matrix shape is not a multiple
2240 // of 4x1, there are some 4x1 blocks along the boundaries that do
2241 // not fit entirely.
2242
2243 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2244 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2245 "sub r1, r1, r8\n"
2246
2247 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2248 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2249 "sub r2, r2, r4\n"
2250 "mov r3, #4\n"
2251 "mov r5, #2\n"
2252 "cmp r1, #4\n"
2253 // Compute r1 = how many rows of the 4x1 block fit
2254 "it gt\n"
2255 "movgt r1, r3\n"
2256
2257 // Test if r1==4, i.e. if all of the 4x1 block fits.
2258 "cmp r1, r3\n"
2259
2260 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2261 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2262 // Yes, all of the 4x1 block fits, go to fast path.
2263 "beq 30f\n"
2264 // Not all of the 4x1 block fits.
2265 // Store to dst_tmp_buf
2266 // Set r3 address to write to dst_tmp_buf.
2267 "mov r3, %[dst_tmp_buf]\n"
2268 "vst1.16 {d28}, [r3]\n"
2269
2270 // Slow loop copying from dst_tmp_buf to dst.
2271 "50:\n"
2272 "mov r8, #0\n"
2273 "51:\n"
2274 // Shift of offset register for half-word loads not allowed in A32,
2275 // so we shift, load/store, then shift back r8.
2276 "lsl r8, r8, #1\n"
2277 "ldrh r10, [r3, r8]\n"
2278 "strh r10, [r4, r8]\n"
2279 "lsr r8, r8, #1\n"
2280 "add r8, r8, #1\n"
2281 "cmp r8, r1\n"
2282 "blt 51b\n"
2283 "b 31f\n"
2284 "30:\n"
2285 // Yes, all of the 4x1 block fits.
2286 // r3 address, r5 stride
2287 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2288 "mov r4, r3\n"
2289 "mov r6, #2\n"
2290
2291 "vst1.16 {d28[0]}, [r3], r6\n"
2292 "vst1.16 {d28[1]}, [r3], r6\n"
2293 "vst1.16 {d28[2]}, [r3], r6\n"
2294 "vst1.16 {d28[3]}, [r3], r6\n"
2295 "31:\n"
2296
2297 // Load dst_ptr, increment, and write back.
2298 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2299 "add r4, r4, #8\n"
2300 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2301
2302 RUY_MAKE_ZERO(q14)
2303
2304 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2305
2306 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
2307
2308 // Since the store type is the same as the accum type, no need for
2309 // downcast. There's also no need for clamp by min/max.
2310
2311 // At this point, v20 -- v31 aren't used anymore for the current block,
2312 // so we can start clearing these accumulators for the next block
2313 // (next iteration of the main loop).
2314 // Clear accumulators.
2315 RUY_MAKE_ZERO(q6)
2316 RUY_MAKE_ZERO(q7)
2317 RUY_MAKE_ZERO(q8)
2318 RUY_MAKE_ZERO(q9)
2319 RUY_MAKE_ZERO(q10)
2320 RUY_MAKE_ZERO(q11)
2321 RUY_MAKE_ZERO(q12)
2322 RUY_MAKE_ZERO(q13)
2323
2324 // Compute how much of the 4x1 block of destination 32 bit values that
2325 // we have computed, fit in the destination matrix. Typically, all of
2326 // it fits, but when the destination matrix shape is not a multiple
2327 // of 4x2, there are some 4x4 blocks along the boundaries that do
2328 // not fit entirely.
2329
2330 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2331 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2332 "sub r1, r1, r8\n"
2333
2334 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2335 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2336 "sub r2, r2, r4\n"
2337 "mov r3, #4\n"
2338 "mov r5, #2\n"
2339 "cmp r1, #4\n"
2340 // Compute r1 = how many rows of the 4x2 block fit
2341 "it gt\n"
2342 "movgt r1, r3\n"
2343
2344 // Test if r1==4, i.e. if all of the 4x1 block fits.
2345 "cmp r1, r3\n"
2346
2347 // Yes, all of the 4x1 block fits, go to fast path.
2348 "beq 30f\n"
2349 // Not all of the 4x1 block fits.
2350 // Set (r3 address, r4 stride) to write to dst_tmp_buf
2351 "mov r3, %[dst_tmp_buf]\n"
2352 "mov r4, #16\n"
2353 "b 31f\n"
2354
2355 "30:\n"
2356 // Yes, all of the 4x1 block fits.
2357 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2358 // r3 address, r4 stride
2359 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2360 "mov r4, r5\n"
2361
2362 "31:\n"
2363
2364 "vst1.32 {d28, d29}, [r3]\n"
2365
2366 // If all of the 4x1 block fits, we just finished writing it to the
2367 // destination, so we skip the next part.
2368 "beq 41f\n"
2369 // Not all of the 4x1 block fits in the destination matrix. We just
2370 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
2371 // it to copy into the destination matrix the part that fits.
2372 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2373 "mov r3, %[dst_tmp_buf]\n"
2374 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2375 "50:\n"
2376 "mov r5, #0\n"
2377 "51:\n"
2378 "ldr r10, [r3, r5, lsl #2]\n"
2379 "str r10, [r4, r5, lsl #2]\n"
2380 "add r5, r5, #1\n"
2381 "cmp r5, r1\n"
2382 "blt 51b\n"
2383
2384 "41:\n"
2385 // Load dst_ptr, increment, and write back.
2386 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2387 "add r4, r4, #16\n"
2388 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2389
2390 RUY_MAKE_ZERO(q10)
2391 RUY_MAKE_ZERO(q11)
2392
2393 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2394
2395 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
2396
2397 // Reload some params --- we had used x5 -- x7 for a few other things
2398 // since the last time we had loaded them.
2399 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2400 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2401 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2402
2403 // Move to the next block of the destination matrix, for the next iter
2404 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
2405 // been updated earlier.
2406 // Have we reached the end row?
2407 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2408 "cmp r8, r3\n"
2409
2410 "beq 20f\n" // yes, end row.
2411 // Not end row. Move to the next row.
2412 "add r8, r8, #4\n"
2413 // Store new value of row
2414 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2415
2416 "b 21f\n"
2417 "20:\n"
2418 // Was already at end row.
2419 // Move back to first row.
2420 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2421 // Move to the next column.
2422 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2423 "add r4, r4, #2\n"
2424 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2425
2426 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2427 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2428 // Increment dst_col_ptr by dst_stride (i.e. 1 column)
2429 "add r1, r1, r8\n"
2430 // Store dst_col_ptr
2431 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2432 // Store dst_ptr
2433 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2434 "21:\n"
2435
2436 // Main loop exit condition: have we hit the end column?
2437 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2438 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2439 "cmp r8, r4\n"
2440
2441 // w1 is the number of levels of depth that we have already loaded
2442 // LHS and RHS data for. Corresponding to the initial ld1 instructions
2443 // above, this is currently 16.
2444 "mov r1, #16\n"
2445
2446 "ble 1b\n"
2447
2448 // Restore stack pointer.
2449 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
2450
2451 // clang-format on
2452
2453 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
2454 : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf)
2455 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
2456 // Clobber list must specify q registers (and not their constituent
2457 // d registers). There is a (currently unexplained) slowdown if
2458 // d registers are listed in the clobbers list.
2459 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
2460 "q9", "q10", "q12", "q13", "q14", "q15");
2461 }
2462
2463 #undef RUY_OFFSET_BIAS
2464 #undef RUY_OFFSET_LHS_SUMS
2465 #undef RUY_OFFSET_RHS_SUMS
2466 #undef RUY_OFFSET_LHS_BASE_PTR
2467 #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
2468 #undef RUY_OFFSET_MULTIPLIER_EXPONENT
2469 #undef RUY_OFFSET_RHS_BASE_PTR
2470 #undef RUY_OFFSET_DST_BASE_PTR
2471 #undef RUY_OFFSET_LHS_ZERO_POINT
2472 #undef RUY_OFFSET_RHS_ZERO_POINT
2473 #undef RUY_OFFSET_DST_ZERO_POINT
2474 #undef RUY_OFFSET_PROD_ZP_DEPTH
2475 #undef RUY_OFFSET_START_ROW
2476 #undef RUY_OFFSET_START_COL
2477 #undef RUY_OFFSET_LAST_ROW
2478 #undef RUY_OFFSET_LAST_COL
2479 #undef RUY_OFFSET_DST_ROWS
2480 #undef RUY_OFFSET_DST_COLS
2481 #undef RUY_OFFSET_LHS_STRIDE
2482 #undef RUY_OFFSET_RHS_STRIDE
2483 #undef RUY_OFFSET_DST_STRIDE
2484 #undef RUY_OFFSET_DEPTH
2485 #undef RUY_OFFSET_CLAMP_MIN
2486 #undef RUY_OFFSET_CLAMP_MAX
2487 #undef RUY_OFFSET_FLAGS
2488 #undef RUY_OFFSET_DST_TYPE_ID
2489
2490 #undef RUY_STACK_OFFSET_SIZE
2491 #undef RUY_STACK_OFFSET_DST_COL_PTR
2492 #undef RUY_STACK_OFFSET_DST_PTR
2493 #undef RUY_STACK_OFFSET_ROW
2494 #undef RUY_STACK_OFFSET_COL
2495 #undef RUY_STACK_OFFSET_LHS_COL_PTR
2496 #undef RUY_STACK_OFFSET_RHS_COL_PTR
2497
2498 #endif // RUY_PLATFORM(NEON_32) && (RUY_OPT_ENABLED(RUY_OPT_ASM)
2499 } // namespace ruy
2500