/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "tensorflow/lite/experimental/ruy/common.h"
#include "tensorflow/lite/experimental/ruy/kernel.h"
#include "tensorflow/lite/experimental/ruy/opt_set.h"
#include "tensorflow/lite/experimental/ruy/platform.h"
#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h"

namespace ruy {

#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)

#define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99
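
// The asm kernels below use GNU inline-asm local numeric labels: branches
// are written as e.g. "blt 2b" (nearest label 2 backward) or "beq 79f"
// (nearest label 79 forward). The constants above reserve the label numbers
// used for the store paths; they are spliced into the asm text via RUY_STR
// stringization.
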
33
34 #define RUY_OFFSET_BIAS 0
35 #define RUY_OFFSET_LHS_SUMS 8
36 #define RUY_OFFSET_RHS_SUMS 16
37 #define RUY_OFFSET_LHS_BASE_PTR 24
38 #define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
39 #define RUY_OFFSET_MULTIPLIER_EXPONENT 40
40 #define RUY_OFFSET_RHS_BASE_PTR 48
41 #define RUY_OFFSET_DST_BASE_PTR 56
42 #define RUY_OFFSET_LHS_ZERO_POINT 64
43 #define RUY_OFFSET_RHS_ZERO_POINT 68
44 #define RUY_OFFSET_DST_ZERO_POINT 72
45 #define RUY_OFFSET_PROD_ZP_DEPTH 76
46 #define RUY_OFFSET_START_ROW 80
47 #define RUY_OFFSET_START_COL 84
48 #define RUY_OFFSET_LAST_ROW 88
49 #define RUY_OFFSET_LAST_COL 92
50 #define RUY_OFFSET_DST_ROWS 96
51 #define RUY_OFFSET_DST_COLS 100
52 #define RUY_OFFSET_LHS_STRIDE 104
53 #define RUY_OFFSET_RHS_STRIDE 108
54 #define RUY_OFFSET_DST_STRIDE 112
55 #define RUY_OFFSET_DEPTH 116
56 #define RUY_OFFSET_CLAMP_MIN 120
57 #define RUY_OFFSET_CLAMP_MAX 124
58 #define RUY_OFFSET_FLAGS 128
59
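// The RUY_OFFSET_* values above are hard-coded so that they can be spliced
// into the inline asm as literal load offsets (see the
// RUY_STR(RUY_OFFSET_*) uses below). The static_asserts in this helper keep
// them in sync with the actual layout of KernelParams8bit: a struct change
// that moves a field fails to compile instead of silently mis-addressing it.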
template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
  static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
                "");
  static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
                "");
  static_assert(offsetof(Params, multiplier_fixedpoint) ==
                    RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
                "");
  static_assert(
      offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
      "");
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
  static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
  static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
}

// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
//
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
  profiler::ScopeLabel label(
      "Kernel (kNeon, optimized for out-of-order cores)");

  CheckOffsetsInKernelParams8bit(params);

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  const std::int8_t* lhs_ptr = lhs_col_ptr;
  const std::int8_t* rhs_ptr = rhs_col_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  void* dst_ptr = dst_col_ptr;
  int row = params.start_row;
  int col = params.start_col;

  // The asm kernel below has the following NEON register allocation:
  //
  // v16 -- v31 are int32 accumulators.
  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
  // v4 -- v7 from RHS:
  //
  //                                   int8 RHS 16x4 block
  //                           /-----------------------------------------\
  //                           |v4.b[0]          ...           v7.b[0]   |
  //                           |  ...                            ...     |
  //                           |v4.b[15]         ...           v7.b[15]  |
  //                           \-----------------------------------------/
  //    int8 LHS 4x16 block
  //  /---------------------\  /-----------------------------------------\
  //  |v0.b[0] ... v0.b[15] |  |v16.4s           ...           v28.4s    |
  //  |v1.b[0] ... v1.b[15] |  |v17.4s           ...           v29.4s    |
  //  |v2.b[0] ... v2.b[15] |  |v18.4s           ...           v30.4s    |
  //  |v3.b[0] ... v3.b[15] |  |v19.4s           ...           v31.4s    |
  //  \---------------------/  \-----------------------------------------/
  //                               int32 accumulators 4x4 block
  //
  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
  asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"

        // clang-format off

        // Load some parameters into registers.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"

        // Load the first 64 bytes of LHS and RHS data.
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"

        // Clear accumulators.
        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // w1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently 16.
        "mov w1, #16\n"

        // Perform the first few multiply-adds on the data that we have already
        // loaded.
        "smull v8.8h, v0.8b, v4.8b\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "smull v11.8h, v3.8b, v4.8b\n"
        "smull v12.8h, v0.8b, v5.8b\n"
        "smull v13.8h, v1.8b, v5.8b\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "smull v15.8h, v3.8b, v5.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"
        "smlal2 v12.8h, v0.16b, v5.16b\n"
        "smlal2 v13.8h, v1.16b, v5.16b\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"

        // Main loop of the whole GEMM, over rows and columns of the
        // destination matrix.
        "1:\n"

        // Reminder - w1 is how many levels of depth we have already loaded
        // data for, w12 is the total depth.
        "cmp w1, w12\n"
        "beq 79f\n"

        "2:\n"

        // Some multiplications and 16-bit accumulation were already done above,
        // so we start right away in the middle.
        "sadalp v16.4s, v8.8h\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "smull v8.8h, v0.8b, v6.8b\n"
        "sadalp v17.4s, v9.8h\n"
        "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
        "smull v9.8h, v1.8b, v6.8b\n"
        "sadalp v18.4s, v10.8h\n"
        "smull v10.8h, v2.8b, v6.8b\n"
        "sadalp v19.4s, v11.8h\n"
        "smull v11.8h, v3.8b, v6.8b\n"
        "sadalp v20.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v7.8b\n"
        "sadalp v21.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v7.8b\n"
        "sadalp v22.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v7.8b\n"
        "sadalp v23.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v7.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v6.16b\n"
        "smlal2 v9.8h, v1.16b, v6.16b\n"
        "smlal2 v10.8h, v2.16b, v6.16b\n"
        "smlal2 v11.8h, v3.16b, v6.16b\n"

        "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"

        "smlal2 v12.8h, v0.16b, v7.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v13.8h, v1.16b, v7.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v14.8h, v2.16b, v7.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v15.8h, v3.16b, v7.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "sadalp v24.4s, v8.8h\n"
        "smull v8.8h, v0.8b, v4.8b\n"
        "sadalp v25.4s, v9.8h\n"
        "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "sadalp v26.4s, v10.8h\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "sadalp v27.4s, v11.8h\n"
        "smull v11.8h, v3.8b, v4.8b\n"
        "sadalp v28.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v5.8b\n"
        "sadalp v29.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v5.8b\n"
        "sadalp v30.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "sadalp v31.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v5.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"

        "smlal2 v12.8h, v0.16b, v5.16b\n"
        "smlal2 v13.8h, v1.16b, v5.16b\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"

        // Each iteration of this loop advances by 16 levels of depth.
        "add w1, w1, #16\n"

        // Loop termination condition
        "cmp w1, w12\n"

        "blt 2b\n"

        "79:\n"

        "sadalp v16.4s, v8.8h\n"
        "smull v8.8h, v0.8b, v6.8b\n"
        "sadalp v17.4s, v9.8h\n"
        "smull v9.8h, v1.8b, v6.8b\n"
        "sadalp v18.4s, v10.8h\n"
        "smull v10.8h, v2.8b, v6.8b\n"
        "sadalp v19.4s, v11.8h\n"
        "smull v11.8h, v3.8b, v6.8b\n"
        "sadalp v20.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v7.8b\n"
        "sadalp v21.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v7.8b\n"
        "sadalp v22.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v7.8b\n"
        "sadalp v23.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v7.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v6.16b\n"
        "smlal2 v9.8h, v1.16b, v6.16b\n"
        "smlal2 v10.8h, v2.16b, v6.16b\n"
        "smlal2 v11.8h, v3.16b, v6.16b\n"

        "smlal2 v12.8h, v0.16b, v7.16b\n"
        "smlal2 v13.8h, v1.16b, v7.16b\n"
        "smlal2 v14.8h, v2.16b, v7.16b\n"
        "smlal2 v15.8h, v3.16b, v7.16b\n"

        "sadalp v24.4s, v8.8h\n"
        "sadalp v25.4s, v9.8h\n"
        "sadalp v26.4s, v10.8h\n"
        "sadalp v27.4s, v11.8h\n"
        "sadalp v28.4s, v12.8h\n"
        "sadalp v29.4s, v13.8h\n"
        "sadalp v30.4s, v14.8h\n"
        "sadalp v31.4s, v15.8h\n"

        // End of accumulation. The registers v16 -- v31 contain the final
        // int32 accumulator values of the current 4x4 destination block.
        // We now have to compute the final 8-bit values from these int32
        // accumulators, and advance to the next 4x4 block. We intertwine
        // these two aspects whenever possible for optimal pipelining, both
        // at the data flow level (prefetch data for next block as early as
        // possible) and instruction pipelining level (some of the next-block
        // work can dual-issue with some of the final work on the current
        // block).

        // Reduce 32bit accumulators horizontally.
        "addp v16.4s, v16.4s, v17.4s\n"
        "addp v18.4s, v18.4s, v19.4s\n"
        "addp v20.4s, v20.4s, v21.4s\n"
        "addp v22.4s, v22.4s, v23.4s\n"
        "addp v24.4s, v24.4s, v25.4s\n"
        "addp v26.4s, v26.4s, v27.4s\n"
        "addp v28.4s, v28.4s, v29.4s\n"
        "addp v30.4s, v30.4s, v31.4s\n"

        // Reduce 32bit accumulators horizontally, second pass
        // (each pass adds pairwise; we need to add 4-wise).
        "addp v16.4s, v16.4s, v18.4s\n"
        "addp v17.4s, v20.4s, v22.4s\n"
        "addp v18.4s, v24.4s, v26.4s\n"
        "addp v19.4s, v28.4s, v30.4s\n"

        // Logic to advance to the next block in preparation for the next
        // iteration of the main loop. For now, we only want to compute
        // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
        // not yet ready to update the values of row and col, as we still need
        // the current values for the rest of the work on the current block.

        "cmp %w[row], w7\n"  // Have we finished the last row?
        "bge 4f\n"           // If finished last row, go to 4
        // Not finished last row: then advance to next row.
        "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
        "b 5f\n"
        "4:\n"  // Finished last row...
        "mov %[lhs_col_ptr], x5\n"  // Go back to first row
        // Now we need to advance to the next column. If we already
        // finished the last column, then in principle we are done, however
        // we can't just return here, as we need to allow the end work of the
        // current block to complete. The good news is that at this point it
        // doesn't matter what data we load for the next column, since
        // we will exit from the main loop below before actually storing
        // anything computed from that data.
        "cmp %w[col], w8\n"  // Have we finished the last column?
        "bge 5f\n" // If yes, just carry on without updating the column pointer.
        // Not finished last column: then advance to next column.
        "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
        "5:\n"

        // Set the LHS and RHS data pointers to the start of the columns just
        // computed.
        "mov %[lhs_ptr], %[lhs_col_ptr]\n"
        "mov %[rhs_ptr], %[rhs_col_ptr]\n"

        // Load some parameters needed for the end work on current block.
        RUY_MAKE_ZERO(v8)
        "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
        "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
        "ins v13.h[4], w4\n" // dst_zero_point
        "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
        "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
        "dup v9.4s, w3\n"   // create prod_zp_depth_vec
        "add x5, x4, %x[row], lsl #2\n"
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
        "csel x4, x4, x5, eq\n"

        "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint

        // Now we load: bias data, LHS sums data, RHS sums data.

        // First, load the base pointers from the params.
        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"

        "add x5, x1, %x[row], lsl #2\n"
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
        "csel x1, x1, x5, eq\n"

        // Load 4 bias values.
        "ld1 {v14.4s}, [x1]\n"

        // Now that we know what LHS and RHS data the next iteration of the
        // main loop will need to load, we start loading the first 32 bytes of
        // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
        // in the rest of the work on the current block.
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"

        // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
        // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
        "add v14.4s, v14.4s, v9.4s\n"

        // Perform the bias-addition (per the above, we have just folded into
        // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
        "add v16.4s, v16.4s, v14.4s\n"
        "add v17.4s, v17.4s, v14.4s\n"
        "add v18.4s, v18.4s, v14.4s\n"
        "add v19.4s, v19.4s, v14.4s\n"

        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
        "beq 401f\n"
        "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
        "add x3, x3, %x[col], lsl #2\n"
        "ld1 {v14.4s}, [x3]\n"
        "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
        "dup v10.4s, w5\n"  // create lhs_zero_point_vec
        // Subtract rhs_sums * lhs_zero_point, per
        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
        "mls v16.4s, v10.4s, v14.s[0]\n"
        "mls v17.4s, v10.4s, v14.s[1]\n"
        "mls v18.4s, v10.4s, v14.s[2]\n"
        "mls v19.4s, v10.4s, v14.s[3]\n"
        "401:\n"

        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
        "beq 402f\n"
        "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
        "add x2, x2, %x[row], lsl #2\n"
        "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
        // Load 4 lhs_sums values.
        "ld1 {v11.4s}, [x2]\n"
        "ins v13.s[1], w5\n" // rhs_zero_point
        // Compute lhs_sums * rhs_zero_point.
        "mul v11.4s, v11.4s, v13.s[1]\n"
        // Subtract lhs_sums * rhs_zero_point, per
        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
        "sub v16.4s, v16.4s, v11.4s\n"
        "sub v17.4s, v17.4s, v11.4s\n"
        "sub v18.4s, v18.4s, v11.4s\n"
        "sub v19.4s, v19.4s, v11.4s\n"

        // If the destination is int32, it means the user asks for the raw
        // accumulators, no need for us to downquantize the value.
        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"

        "402:\n"

        // At this point we have computed the final int32 values. Now we
        // start down-quantizing them to obtain the final 8bit values from them.

        // As part of this down-quantization, our int32 values will be
        // multiplied by a multiplier that has a fixed-point component and an
        // exponent component.

        // Load the exponent part of the multiplier.
        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
        "add x5, x1, %x[row], lsl #2\n"
        "csel x1, x1, x5, eq\n"

        "ld1 {v14.4s}, [x1]\n"

        "smax v12.4s, v14.4s, v8.4s\n"

        "sshl v16.4s, v16.4s, v12.4s\n"
        "sshl v17.4s, v17.4s, v12.4s\n"
        "sshl v18.4s, v18.4s, v12.4s\n"
        "sshl v19.4s, v19.4s, v12.4s\n"

        "smin v12.4s, v14.4s, v8.4s\n"

        // Apply the fixed-point part of the multiplier.
        "sqrdmulh v16.4s, v16.4s, v15.4s\n"
        "sqrdmulh v17.4s, v17.4s, v15.4s\n"
        "sqrdmulh v18.4s, v18.4s, v15.4s\n"
        "sqrdmulh v19.4s, v19.4s, v15.4s\n"

        // We have some rounding division-by-power-of-two to do. This should
        // always use "round to nearest". We allow for some
        // freedom in how ties are broken, to strike a good compromise of
        // performance on given hardware vs. perfect agreement of results
        // across hardware.
        //
        // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
        // defined tie-breaks to help performance. On NEON, this means that we
        // can just use the NEON rounding instructions, such as srshl. They
        // happen to be breaking ties upward.
        //
        // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
        // break-ties-away-from zero, as described in Appendix B of
        // https://arxiv.org/pdf/1712.05877.pdf
        // When we wrote that, we thought that that would be better unbiased
        // than the NEON upwards tie-breaks, and we had observed some
        // improvement on some model. However, that is only more unbiased for
        // data centered at zero, which was likely the case in that model,
        // but is not always the case. If we wanted something more consistently
        // unbiased then we should try breaking ties toward-nearest-even.
#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
        // Fix up values to be right-shifted, so that the (round to nearest,
        // break ties upward) behavior of srshl applied to these fixed-up
        // values, produces the same result as the desired (round to nearest,
        // break ties away from zero) behavior on the original values.
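        //
        // For example, take x = -10 with a right shift by 2. srshl alone
        // computes (-10 + 2) >> 2 = -2, breaking the tie in round(-2.5)
        // upward, whereas break-ties-away-from-zero wants -3. The sign-mask
        // fixup below subtracts 1 from negative values, so srshl then sees
        // (-11 + 2) >> 2 = -3, as desired. Non-tie values are unaffected:
        // x = -9 becomes -10, and (-10 + 2) >> 2 = -2 = round(-2.25).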
547 "and v8.16b, v16.16b, v12.16b\n"
548 "and v9.16b, v17.16b, v12.16b\n"
549 "and v14.16b, v18.16b, v12.16b\n"
550 "and v15.16b, v19.16b, v12.16b\n"
551 "sshr v8.4s, v8.4s, #31\n"
552 "sshr v9.4s, v9.4s, #31\n"
553 "sshr v14.4s, v14.4s, #31\n"
554 "sshr v15.4s, v15.4s, #31\n"
555 "sqadd v16.4s, v16.4s, v8.4s\n"
556 "sqadd v17.4s, v17.4s, v9.4s\n"
557 "sqadd v18.4s, v18.4s, v14.4s\n"
558 "sqadd v19.4s, v19.4s, v15.4s\n"
559 #endif
560 // At this point we have reduced the problem of correctly implementing
561 // rounding divide-by-power-of-two, to what the SRSHL instruction can
562 // do.
563 "srshl v16.4s, v16.4s, v12.4s\n"
564 "srshl v17.4s, v17.4s, v12.4s\n"
565 "srshl v18.4s, v18.4s, v12.4s\n"
566 "srshl v19.4s, v19.4s, v12.4s\n"
567
568 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
569 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
570 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
571 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
572
573 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
574
575 // Cast-and-saturate from int32 to int16
576 "sqxtn v16.4h, v16.4s\n"
577 "sqxtn2 v16.8h, v17.4s\n"
578 "sqxtn v17.4h, v18.4s\n"
579 "sqxtn2 v17.8h, v19.4s\n"
580
581 // At this point, v18 -- v31 aren't used anymore for the current block,
582 // so we can start clearing these accumulators for the next block
583 // (next iteration of the main loop).
584 RUY_MAKE_ZERO(v18)
585 RUY_MAKE_ZERO(v19)
586 RUY_MAKE_ZERO(v20)
587 RUY_MAKE_ZERO(v21)
588 RUY_MAKE_ZERO(v22)
589 RUY_MAKE_ZERO(v23)
590 RUY_MAKE_ZERO(v24)
591 RUY_MAKE_ZERO(v25)
592 RUY_MAKE_ZERO(v26)
593 RUY_MAKE_ZERO(v27)
594 RUY_MAKE_ZERO(v28)
595 RUY_MAKE_ZERO(v29)
596 RUY_MAKE_ZERO(v30)
597 RUY_MAKE_ZERO(v31)
598
599 // Add the destination zero point
600 "dup v14.8h, v13.h[4]\n"
601 "add v16.8h, v16.8h, v14.8h\n"
602 "add v17.8h, v17.8h, v14.8h\n"
603
604 // Cast-and-saturate from int16 to uint8
605 "sqxtun v16.8b, v16.8h\n"
606 "sqxtun2 v16.16b, v17.8h\n"
607
608 // Load the clamp_min, clamp_max bounds
609 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
610 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
611 "dup v14.16b, w2\n" // clamp_min
612 "dup v15.16b, w3\n" // clamp_max
613
614 // Apply the clamp_min bound
615 "umax v16.16b, v16.16b, v14.16b\n"
616 // Apply the clamp_max bound
617 "umin v16.16b, v16.16b, v15.16b\n"
618
619 // Compute how much of the 4x4 block of destination 8bit values that
620 // we have computed, fit in the destination matrix. Typically, all of
621 // it fits, but when the destination matrix shape is not a multiple
622 // of 4x4, there are some 4x4 blocks along the boundaries that do
623 // not fit entirely.
624 "sub w1, %w[dst_rows], %w[row]\n"
625 "sub w2, %w[dst_cols], %w[col]\n"
626 "mov w3, #4\n"
627 "cmp w1, #4\n"
628 // Compute w1 = how many rows of the 4x4 block fit
629 "csel w1, w1, w3, le\n"
630 "cmp w2, #4\n"
631 // Compute w2 = how many cols of the 4x4 block fit
632 "csel w2, w2, w3, le\n"
633
634 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
635 "cmp w1, w3\n"
636 "ccmp w2, w3, 0, eq\n"
637 "mov x4, %[dst_ptr]\n"
638 // Yes, all of the 4x4 block fits, go to fast path.
639 "beq 30f\n"
640 // Not all of the 4x4 block fits.
641 // Store to dst_tmp_buf
642 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
643 // Slow loop copying from dst_tmp_buf to dst.
644 "mov x3, %[dst_tmp_buf]\n"
645 "mov w6, #0\n"
646 "50:\n"
647 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
648 "mov w5, #0\n"
649 "51:\n"
650 "ldrb w7, [x3, w5, uxtw]\n"
651 "strb w7, [x4, w5, uxtw]\n"
652 "add w5, w5, #1\n"
653 "cmp w5, w1\n"
654 "blt 51b\n"
655 "add w6, w6, #1\n"
656 "add x3, x3, #4\n"
657 "add x4, x4, x11\n"
658 "cmp w6, w2\n"
659 "blt 50b\n"
660 "b 31f\n"
661 "30:\n"
662 // Yes, all of the 4x4 block fits.
663 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
664 "mov x3, x4\n"
665 "st1 {v16.b}[0], [x3], #1\n"
666 "add x4, x4, x11\n"
667 "st1 {v16.b}[1], [x3], #1\n"
668 "st1 {v16.b}[2], [x3], #1\n"
669 "st1 {v16.b}[3], [x3], #1\n"
670 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
671 "mov x3, x4\n"
672 "st1 {v16.b}[4], [x3], #1\n"
673 "add x4, x4, x11\n"
674 "st1 {v16.b}[5], [x3], #1\n"
675 "st1 {v16.b}[6], [x3], #1\n"
676 "st1 {v16.b}[7], [x3], #1\n"
677 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
678 "mov x3, x4\n"
679 "st1 {v16.b}[8], [x3], #1\n"
680 "add x4, x4, x11\n"
681 "st1 {v16.b}[9], [x3], #1\n"
682 "st1 {v16.b}[10], [x3], #1\n"
683 "st1 {v16.b}[11], [x3], #1\n"
684 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
685 "mov x3, x4\n"
686 "st1 {v16.b}[12], [x3], #1\n"
687 "add x4, x4, x11\n"
688 "st1 {v16.b}[13], [x3], #1\n"
689 "st1 {v16.b}[14], [x3], #1\n"
690 "st1 {v16.b}[15], [x3], #1\n"
691 "31:\n"
692
693 "add %[dst_ptr], %[dst_ptr], #4\n"
694
695 RUY_MAKE_ZERO(v16)
696 RUY_MAKE_ZERO(v17)
697
698 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
699
700 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
701
702 // Cast-and-saturate from int32 to int16
703 "sqxtn v16.4h, v16.4s\n"
704 "sqxtn2 v16.8h, v17.4s\n"
705 "sqxtn v17.4h, v18.4s\n"
706 "sqxtn2 v17.8h, v19.4s\n"
707
708 // At this point, v18 -- v31 aren't used anymore for the current block,
709 // so we can start clearing these accumulators for the next block
710 // (next iteration of the main loop).
711 RUY_MAKE_ZERO(v18)
712 RUY_MAKE_ZERO(v19)
713 RUY_MAKE_ZERO(v20)
714 RUY_MAKE_ZERO(v21)
715 RUY_MAKE_ZERO(v22)
716 RUY_MAKE_ZERO(v23)
717 RUY_MAKE_ZERO(v24)
718 RUY_MAKE_ZERO(v25)
719 RUY_MAKE_ZERO(v26)
720 RUY_MAKE_ZERO(v27)
721 RUY_MAKE_ZERO(v28)
722 RUY_MAKE_ZERO(v29)
723 RUY_MAKE_ZERO(v30)
724 RUY_MAKE_ZERO(v31)
725
726 // Add the destination zero point
727 "dup v14.8h, v13.h[4]\n"
728 "add v16.8h, v16.8h, v14.8h\n"
729 "add v17.8h, v17.8h, v14.8h\n"
730
731 // Cast-and-saturate from int16 to int8
732 "sqxtn v16.8b, v16.8h\n"
733 "sqxtn2 v16.16b, v17.8h\n"
734
735 // Load the clamp_min, clamp_max bounds
736 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
737 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
738 "dup v14.16b, w2\n" // clamp_min
739 "dup v15.16b, w3\n" // clamp_max
740
741 // Apply the clamp_min bound
742 "smax v16.16b, v16.16b, v14.16b\n"
743 // Apply the clamp_max bound
744 "smin v16.16b, v16.16b, v15.16b\n"
745
746 // Compute how much of the 4x4 block of destination 8bit values that
747 // we have computed, fit in the destination matrix. Typically, all of
748 // it fits, but when the destination matrix shape is not a multiple
749 // of 4x4, there are some 4x4 blocks along the boundaries that do
750 // not fit entirely.
751 "sub w1, %w[dst_rows], %w[row]\n"
752 "sub w2, %w[dst_cols], %w[col]\n"
753 "mov w3, #4\n"
754 "cmp w1, #4\n"
755 // Compute w1 = how many rows of the 4x4 block fit
756 "csel w1, w1, w3, le\n"
757 "cmp w2, #4\n"
758 // Compute w2 = how many cols of the 4x4 block fit
759 "csel w2, w2, w3, le\n"
760
761 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
762 "cmp w1, w3\n"
763 "ccmp w2, w3, 0, eq\n"
764 "mov x4, %[dst_ptr]\n"
765 // Yes, all of the 4x4 block fits, go to fast path.
766 "beq 30f\n"
767 // Not all of the 4x4 block fits.
768 // Store to dst_tmp_buf
769 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
770 // Slow loop copying from dst_tmp_buf to dst.
771 "mov x3, %[dst_tmp_buf]\n"
772 "mov w6, #0\n"
773 "50:\n"
774 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
775 "mov w5, #0\n"
776 "51:\n"
777 "ldrb w7, [x3, w5, uxtw]\n"
778 "strb w7, [x4, w5, uxtw]\n"
779 "add w5, w5, #1\n"
780 "cmp w5, w1\n"
781 "blt 51b\n"
782 "add w6, w6, #1\n"
783 "add x3, x3, #4\n"
784 "add x4, x4, x11\n"
785 "cmp w6, w2\n"
786 "blt 50b\n"
787 "b 31f\n"
788 "30:\n"
789 // Yes, all of the 4x4 block fits.
790 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
791 "mov x3, x4\n"
792 "st1 {v16.b}[0], [x3], #1\n"
793 "add x4, x4, x11\n"
794 "st1 {v16.b}[1], [x3], #1\n"
795 "st1 {v16.b}[2], [x3], #1\n"
796 "st1 {v16.b}[3], [x3], #1\n"
797 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
798 "mov x3, x4\n"
799 "st1 {v16.b}[4], [x3], #1\n"
800 "add x4, x4, x11\n"
801 "st1 {v16.b}[5], [x3], #1\n"
802 "st1 {v16.b}[6], [x3], #1\n"
803 "st1 {v16.b}[7], [x3], #1\n"
804 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
805 "mov x3, x4\n"
806 "st1 {v16.b}[8], [x3], #1\n"
807 "add x4, x4, x11\n"
808 "st1 {v16.b}[9], [x3], #1\n"
809 "st1 {v16.b}[10], [x3], #1\n"
810 "st1 {v16.b}[11], [x3], #1\n"
811 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
812 "mov x3, x4\n"
813 "st1 {v16.b}[12], [x3], #1\n"
814 "add x4, x4, x11\n"
815 "st1 {v16.b}[13], [x3], #1\n"
816 "st1 {v16.b}[14], [x3], #1\n"
817 "st1 {v16.b}[15], [x3], #1\n"
818 "31:\n"
819
820 "add %[dst_ptr], %[dst_ptr], #4\n"
821
822 RUY_MAKE_ZERO(v16)
823 RUY_MAKE_ZERO(v17)
824
825 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
826
827 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
828
829 // Add the destination zero point
830 "dup v14.4h, v13.h[4]\n"
831 "saddw v16.4s, v16.4s, v14.4h\n"
832 "saddw v17.4s, v17.4s, v14.4h\n"
833 "saddw v18.4s, v18.4s, v14.4h\n"
834 "saddw v19.4s, v19.4s, v14.4h\n"
835
836 // Cast-and-saturate from int32 to int16
837 "sqxtn v16.4h, v16.4s\n"
838 "sqxtn2 v16.8h, v17.4s\n"
839 "sqxtn v17.4h, v18.4s\n"
840 "sqxtn2 v17.8h, v19.4s\n"
841
842 // At this point, v18 -- v31 aren't used anymore for the current block,
843 // so we can start clearing these accumulators for the next block
844 // (next iteration of the main loop).
845 RUY_MAKE_ZERO(v18)
846 RUY_MAKE_ZERO(v19)
847 RUY_MAKE_ZERO(v20)
848 RUY_MAKE_ZERO(v21)
849 RUY_MAKE_ZERO(v22)
850 RUY_MAKE_ZERO(v23)
851 RUY_MAKE_ZERO(v24)
852 RUY_MAKE_ZERO(v25)
853 RUY_MAKE_ZERO(v26)
854 RUY_MAKE_ZERO(v27)
855 RUY_MAKE_ZERO(v28)
856 RUY_MAKE_ZERO(v29)
857 RUY_MAKE_ZERO(v30)
858 RUY_MAKE_ZERO(v31)
859
860 // Load the clamp_min, clamp_max bounds
861 "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
862 "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
863 "dup v14.8h, w2\n" // clamp_min
864 "dup v15.8h, w3\n" // clamp_max
865
866 // Apply the clamp_min bound
867 "smax v16.8h, v16.8h, v14.8h\n"
868 "smax v17.8h, v17.8h, v14.8h\n"
869 // Apply the clamp_max bound
870 "smin v16.8h, v16.8h, v15.8h\n"
871 "smin v17.8h, v17.8h, v15.8h\n"
872
873 // Compute how much of the 4x4 block of destination 8bit values that
874 // we have computed, fit in the destination matrix. Typically, all of
875 // it fits, but when the destination matrix shape is not a multiple
876 // of 4x4, there are some 4x4 blocks along the boundaries that do
877 // not fit entirely.
878 "sub w1, %w[dst_rows], %w[row]\n"
879 "sub w2, %w[dst_cols], %w[col]\n"
880 "mov w3, #4\n"
881 "cmp w1, #4\n"
882 // Compute w1 = how many rows of the 4x4 block fit
883 "csel w1, w1, w3, le\n"
884 "cmp w2, #4\n"
885 // Compute w2 = how many cols of the 4x4 block fit
886 "csel w2, w2, w3, le\n"
887
        // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
889 "cmp w1, w3\n"
890 "ccmp w2, w3, 0, eq\n"
891 "mov x4, %[dst_ptr]\n"
892 // Yes, all of the 4x4 block fits, go to fast path.
893 "beq 30f\n"
894 // Not all of the 4x4 block fits.
895 // Store to dst_tmp_buf
896 "str q16, [%[dst_tmp_buf], #0]\n"
897 "str q17, [%[dst_tmp_buf], #16]\n"
898 // Slow loop copying from dst_tmp_buf to dst.
899 "mov x3, %[dst_tmp_buf]\n"
900 "mov w6, #0\n"
901 "50:\n"
902 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
903 "mov w5, #0\n"
904 "51:\n"
905 "ldrh w7, [x3, x5, lsl #1]\n"
906 "strh w7, [x4, x5, lsl #1]\n"
907 "add w5, w5, #1\n"
908 "cmp w5, w1\n"
909 "blt 51b\n"
910 "add w6, w6, #1\n"
911 "add x3, x3, #8\n"
912 "add x4, x4, x11\n"
913 "cmp w6, w2\n"
914 "blt 50b\n"
915 "b 31f\n"
916 "30:\n"
917 // Yes, all of the 4x4 block fits.
918 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
919 "mov x3, x4\n"
920 "st1 {v16.h}[0], [x3], #2\n"
921 "add x4, x4, x11\n"
922 "st1 {v16.h}[1], [x3], #2\n"
923 "st1 {v16.h}[2], [x3], #2\n"
924 "st1 {v16.h}[3], [x3], #2\n"
925 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
926 "mov x3, x4\n"
927 "st1 {v16.h}[4], [x3], #2\n"
928 "add x4, x4, x11\n"
929 "st1 {v16.h}[5], [x3], #2\n"
930 "st1 {v16.h}[6], [x3], #2\n"
931 "st1 {v16.h}[7], [x3], #2\n"
932 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
933 "mov x3, x4\n"
934 "st1 {v17.h}[0], [x3], #2\n"
935 "add x4, x4, x11\n"
936 "st1 {v17.h}[1], [x3], #2\n"
937 "st1 {v17.h}[2], [x3], #2\n"
938 "st1 {v17.h}[3], [x3], #2\n"
939 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
940 "mov x3, x4\n"
941 "st1 {v17.h}[4], [x3], #2\n"
942 "add x4, x4, x11\n"
943 "st1 {v17.h}[5], [x3], #2\n"
944 "st1 {v17.h}[6], [x3], #2\n"
945 "st1 {v17.h}[7], [x3], #2\n"
946 "31:\n"
947
948 "add %[dst_ptr], %[dst_ptr], #8\n"
949
950 RUY_MAKE_ZERO(v16)
951 RUY_MAKE_ZERO(v17)
952
953 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
954
955 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
956
957 // Since the store type is the same as the accum type, no need for
958 // downcast. There's also no need for clamp by min/max.
959
960 // At this point, v20 -- v31 aren't used anymore for the current block,
961 // so we can start clearing these accumulators for the next block
962 // (next iteration of the main loop).
963 RUY_MAKE_ZERO(v20)
964 RUY_MAKE_ZERO(v21)
965 RUY_MAKE_ZERO(v22)
966 RUY_MAKE_ZERO(v23)
967 RUY_MAKE_ZERO(v24)
968 RUY_MAKE_ZERO(v25)
969 RUY_MAKE_ZERO(v26)
970 RUY_MAKE_ZERO(v27)
971 RUY_MAKE_ZERO(v28)
972 RUY_MAKE_ZERO(v29)
973 RUY_MAKE_ZERO(v30)
974 RUY_MAKE_ZERO(v31)
975
976 // Compute how much of the 4x4 block of destination 8bit values that
977 // we have computed, fit in the destination matrix. Typically, all of
978 // it fits, but when the destination matrix shape is not a multiple
979 // of 4x4, there are some 4x4 blocks along the boundaries that do
980 // not fit entirely.
981 "sub w1, %w[dst_rows], %w[row]\n"
982 "sub w2, %w[dst_cols], %w[col]\n"
983 "mov w3, #4\n"
984 "cmp w1, #4\n"
985 // Compute w1 = how many rows of the 4x4 block fit
986 "csel w1, w1, w3, le\n"
987 "cmp w2, #4\n"
988 // Compute w2 = how many cols of the 4x4 block fit
989 "csel w2, w2, w3, le\n"
990
        // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
992 "cmp w1, w3\n"
993 "ccmp w2, w3, 0, eq\n"
994 "mov x4, %[dst_ptr]\n"
995 // Yes, all of the 4x4 block fits, go to fast path.
996 "beq 30f\n"
997 // Not all of the 4x4 block fits.
998 // Store to dst_tmp_buf
999 "str q16, [%[dst_tmp_buf], #0]\n"
1000 "str q17, [%[dst_tmp_buf], #16]\n"
1001 "str q18, [%[dst_tmp_buf], #32]\n"
1002 "str q19, [%[dst_tmp_buf], #48]\n"
1003 // Slow loop copying from dst_tmp_buf to dst.
1004 "mov x3, %[dst_tmp_buf]\n"
1005 "mov w6, #0\n"
1006 "50:\n"
1007 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1008 "mov w5, #0\n"
1009 "51:\n"
1010 "ldr w7, [x3, x5, lsl #2]\n"
1011 "str w7, [x4, x5, lsl #2]\n"
1012 "add w5, w5, #1\n"
1013 "cmp w5, w1\n"
1014 "blt 51b\n"
1015 "add w6, w6, #1\n"
1016 "add x3, x3, #16\n"
1017 "add x4, x4, x11\n"
1018 "cmp w6, w2\n"
1019 "blt 50b\n"
1020 "b 31f\n"
1021 "30:\n"
1022 // Yes, all of the 4x4 block fits.
1023 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1024 "mov x3, x4\n"
1025 "st1 {v16.s}[0], [x3], #4\n"
1026 "add x4, x4, x11\n"
1027 "st1 {v16.s}[1], [x3], #4\n"
1028 "st1 {v16.s}[2], [x3], #4\n"
1029 "st1 {v16.s}[3], [x3], #4\n"
1030 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1031 "mov x3, x4\n"
1032 "st1 {v17.s}[0], [x3], #4\n"
1033 "add x4, x4, x11\n"
1034 "st1 {v17.s}[1], [x3], #4\n"
1035 "st1 {v17.s}[2], [x3], #4\n"
1036 "st1 {v17.s}[3], [x3], #4\n"
1037 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1038 "mov x3, x4\n"
1039 "st1 {v18.s}[0], [x3], #4\n"
1040 "add x4, x4, x11\n"
1041 "st1 {v18.s}[1], [x3], #4\n"
1042 "st1 {v18.s}[2], [x3], #4\n"
1043 "st1 {v18.s}[3], [x3], #4\n"
1044 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1045 "mov x3, x4\n"
1046 "st1 {v19.s}[0], [x3], #4\n"
1047 "add x4, x4, x11\n"
1048 "st1 {v19.s}[1], [x3], #4\n"
1049 "st1 {v19.s}[2], [x3], #4\n"
1050 "st1 {v19.s}[3], [x3], #4\n"
1051 "31:\n"
1052
1053 "add %[dst_ptr], %[dst_ptr], #16\n"
1054
1055 RUY_MAKE_ZERO(v16)
1056 RUY_MAKE_ZERO(v17)
1057 RUY_MAKE_ZERO(v18)
1058 RUY_MAKE_ZERO(v19)
1059
1060 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1061
1062 // For the next block: perform the first few multiply-adds on the data
1063 // that we have already loaded.
1064 "smull v8.8h, v0.8b, v4.8b\n"
1065 "smull v9.8h, v1.8b, v4.8b\n"
1066 "smull v10.8h, v2.8b, v4.8b\n"
1067 "smull v11.8h, v3.8b, v4.8b\n"
1068 "smull v12.8h, v0.8b, v5.8b\n"
1069 "smull v13.8h, v1.8b, v5.8b\n"
1070 "smull v14.8h, v2.8b, v5.8b\n"
1071 "smull v15.8h, v3.8b, v5.8b\n"
1072 "smlal2 v8.8h, v0.16b, v4.16b\n"
1073 "smlal2 v9.8h, v1.16b, v4.16b\n"
1074 "smlal2 v10.8h, v2.16b, v4.16b\n"
1075 "smlal2 v11.8h, v3.16b, v4.16b\n"
1076 "smlal2 v12.8h, v0.16b, v5.16b\n"
1077 "smlal2 v13.8h, v1.16b, v5.16b\n"
1078 "smlal2 v14.8h, v2.16b, v5.16b\n"
1079 "smlal2 v15.8h, v3.16b, v5.16b\n"
1080
1081 // Reload some params --- we had used x5 -- x7 for a few other things
1082 // since the last time we had loaded them.
1083 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1084 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1085 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1086
1087 // Move to the next block of the destination matrix, for the next iter
1088 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
1089 // been updated earlier.
1090 // Have we reached the end row?
1091 "cmp %w[row], w7\n"
1092 "beq 20f\n" // yes, end row.
1093 // Not end row. Move to the next row.
1094 "add %w[row], %w[row], #4\n"
1095 "b 21f\n"
1096 "20:\n"
1097 // Was already at end row.
1098 "mov %w[row], w6\n" // Move back to first row.
1099 "add %w[col], %w[col], #4\n" // Move to the next column.
1100 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
1101 "mov %[dst_ptr], %[dst_col_ptr]\n"
1102 "21:\n"
1103
1104 // Main loop exit condition: have we hit the end column?
1105 "cmp %w[col], w8\n"
1106
        // w1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently 16.
        "mov w1, #16\n"

        "ble 1b\n"

        // clang-format on

        : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
        : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
          [dst_type_id] "r"(params.dst_type_id)
        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
          "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
          "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
          "v26", "v27", "v28", "v29", "v30", "v31");
}

// Similar to the existing Kernel8bitNeonOutOfOrder but specialized for the
// case of RHS cols == 1.
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) {
  profiler::ScopeLabel label(
      "Kernel (kNeon, optimized for out-of-order cores)");

  CheckOffsetsInKernelParams8bit(params);

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  const std::int8_t* lhs_ptr = lhs_col_ptr;
  const std::int8_t* rhs_ptr = rhs_col_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  void* dst_ptr = dst_col_ptr;
  int row = params.start_row;
  int col = params.start_col;

  // The asm kernel below has the following NEON register allocation:
  //
  // v16 -- v19 are int32 accumulators.
  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
  // v4 from RHS:
  //
  //                           int8 RHS 16x1 block
  //                             /-----------\
  //                             |v4.b[0]    |
  //                             |  ...      |
  //                             |v4.b[15]   |
  //                             \-----------/
  //    int8 LHS 4x16 block
  //  /---------------------\    /-----------\
  //  |v0.b[0] ... v0.b[15] |    |v16.4s     |
  //  |v1.b[0] ... v1.b[15] |    |v17.4s     |
  //  |v2.b[0] ... v2.b[15] |    |v18.4s     |
  //  |v3.b[0] ... v3.b[15] |    |v19.4s     |
  //  \---------------------/    \-----------/
  //                       int32 accumulators 4x1 block
  //
  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
  asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"

        // clang-format off

        // Load some parameters into registers.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"

        // Load the first 64 bytes of LHS and 16 bytes of RHS data.
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
1191 "add %[rhs_ptr], %[rhs_ptr], #48\n"
1192
1193 // Clear accumulators.
1194 RUY_MAKE_ZERO(v16)
1195 RUY_MAKE_ZERO(v17)
1196 RUY_MAKE_ZERO(v18)
1197 RUY_MAKE_ZERO(v19)
1198
1199 // w1 is the number of levels of depth that we have already loaded
1200 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1201 // above, this is currently 16.
1202 "mov w1, #16\n"
1203
1204 // Perform the first few multiply-adds on the data that we have already
1205 // loaded.
1206 "smull v8.8h, v0.8b, v4.8b\n"
1207 "smull v9.8h, v1.8b, v4.8b\n"
1208 "smull v10.8h, v2.8b, v4.8b\n"
1209 "smull v11.8h, v3.8b, v4.8b\n"
1210
1211 // Multiply-accumulate second-half, again into the same
1212 // 16bit local accumulator registers. This is where we
1213 // take advantage of having int8 instead of uint8 and therefore
1214 // being able to accumulate two products into int16.
1215 "smlal2 v8.8h, v0.16b, v4.16b\n"
1216 "smlal2 v9.8h, v1.16b, v4.16b\n"
1217 "smlal2 v10.8h, v2.16b, v4.16b\n"
1218 "smlal2 v11.8h, v3.16b, v4.16b\n"
1219
1220 // Main loop of the whole GEMM, over rows and columns of the
1221 // destination matrix.
1222 "1:\n"
1223
1224 // Reminder - w1 is how many levels of depth we have already loaded
1225 // data for, w12 is the total depth.
1226 "cmp w1, w12\n"
1227 "beq 79f\n"
1228
1229 "2:\n"
1230
1231 // Some multiplications and 16-bit accumulation were already done above,
1232 // so we start right away in the middle.
1233 "sadalp v16.4s, v8.8h\n"
1234 "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
1235 "add %[rhs_ptr], %[rhs_ptr], #48\n"
1236 "sadalp v17.4s, v9.8h\n"
1237 "sadalp v18.4s, v10.8h\n"
1238 "sadalp v19.4s, v11.8h\n"
1239
1240 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
1241 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
1242 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
1243 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
1244
1245 "smull v8.8h, v0.8b, v4.8b\n"
1246 "smull v9.8h, v1.8b, v4.8b\n"
1247 "smull v10.8h, v2.8b, v4.8b\n"
1248 "smull v11.8h, v3.8b, v4.8b\n"
1249
1250 // Multiply-accumulate second-half, again into the same
1251 // 16bit local accumulator registers. This is where we
1252 // take advantage of having int8 instead of uint8 and therefore
1253 // being able to accumulate two products into int16.
1254 "smlal2 v8.8h, v0.16b, v4.16b\n"
1255 "smlal2 v9.8h, v1.16b, v4.16b\n"
1256 "smlal2 v10.8h, v2.16b, v4.16b\n"
1257 "smlal2 v11.8h, v3.16b, v4.16b\n"
1258
1259 // Each iteration of this loop advances by 16 levels of depth.
1260 "add w1, w1, #16\n"
1261
1262 // Loop termination condition
1263 "cmp w1, w12\n"
1264
1265 "blt 2b\n"
1266
1267 "79:\n"
1268
1269 "sadalp v16.4s, v8.8h\n"
1270 "sadalp v17.4s, v9.8h\n"
1271 "sadalp v18.4s, v10.8h\n"
1272 "sadalp v19.4s, v11.8h\n"
1273
1274 // End of accumulation. The registers v16 -- v19 contain the final
1275 // int32 accumulator values of the current 4x1 destination block.
1276 // We now have to compute the final 8-bit values from these int32
1277 // accumulators, and advance to the next 4x1 block. We intertwine
1278 // these two aspects whenever possible for optimal pipelining, both
1279 // at the data flow level (prefetch data for next block as early as
1280 // possible) and instruction pipelining level (some of the next-block
1281 // work can dual-issue with some of the final work on the current
1282 // block).
1283
1284 // Reduce 32bit accumulators horizontally.
1285 "addp v16.4s, v16.4s, v17.4s\n"
1286 "addp v18.4s, v18.4s, v19.4s\n"
1287
1288 // Reduce 32bit accumulators horizontally, second pass
1289 // (each pass adds pairwise. we need to add 4-wise).
1290 "addp v16.4s, v16.4s, v18.4s\n"
1291
1292 // Logic to advance to the next block in preparation for the next
1293 // iteration of the main loop. For now, we only want to compute
1294 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
1295 // not yet ready to update the values of row and col, as we still need
1296 // the current values for the rest of the work on the current block.
1297
1298 "cmp %w[row], w7\n" // Have we finished the last row?
1299 "bge 4f\n" // If finished last row, go to 4
1300 // Not finished last row: then advance to next row.
1301 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
1302 "b 5f\n"
1303 "4:\n" // Finished last row...
1304 "mov %[lhs_col_ptr], x5\n" // Go back to first row
1305 // Now we need to advance to the next column. If we already
1306 // finished the last column, then in principle we are done, however
1307 // we can't just return here, as we need to allow the end work of the
1308 // current block to complete. The good news is that at this point it
1309 // doesn't matter what data we load for the next column, since
1310 // we will exit from the main loop below before actually storing
1311 // anything computed from that data.
1312 "cmp %w[col], w8\n" // Have we finished the last column?
1313 "bge 5f\n" // If yes, just carry on without updating the column pointer.
1314 // Not finished last column: then advance to next column.
1315 // (still multiply column stride by 4 due to packing)
1316 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
1317 "5:\n"
1318
1319 // Set the LHS and RHS data pointers to the start of the columns just
1320 // computed.
1321 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
1322 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
1323
1324 // Load some parameters needed for the end work on current block.
1325 RUY_MAKE_ZERO(v8)
1326 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1327 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
1328 "ins v13.h[4], w4\n" // dst_zero_point
1329 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1330 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1331 "dup v9.4s, w3\n" // create prod_zp_depth_vec
1332 "add x5, x4, %x[row], lsl #2\n"
1333 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1334 "csel x4, x4, x5, eq\n"
1335
1336 "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
1337
1338 // Now we load: bias data, LHS sums data, RHS sums data.
1339
1340 // First, load the base pointers from the params.
1341 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
1342
1343 "add x5, x1, %x[row], lsl #2\n"
1344 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
1345 "csel x1, x1, x5, eq\n"
1346
1347 // Load 4 bias values.
1348 "ld1 {v14.4s}, [x1]\n"
1349
1350 // Now that we know what LHS and RHS data the next iteration of the
1351 // main loop will need to load, we start loading the first 32 bytes of
1352 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
1353 // in the rest of the work on the current block.
1354 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
1355 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
1356 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
1357 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
1358 "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
1359 "add %[rhs_ptr], %[rhs_ptr], #48\n"
1360
1361 // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
1362 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1363 "add v14.4s, v14.4s, v9.4s\n"
1364
1365 // Perform the bias-addition (per the above, we have just folded into
1366 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
1367 // (all four 32-bit accumulators are in v16 at this point)
1368 "add v16.4s, v16.4s, v14.4s\n"
1369
1370 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
1371 "beq 401f\n"
1372 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
1373 "add x3, x3, %x[col], lsl #2\n"
1374 "ld1 {v14.4s}, [x3]\n"
1375 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
1376 "dup v10.4s, w5\n" // create lhs_zero_point_vec
1377 // Subtract rhs_sums * lhs_zero_point, per
1378 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1379 "mls v16.4s, v10.4s, v14.s[0]\n"
1380 "401:\n"
1381
1382 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
1383 "beq 402f\n"
1384 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
1385 "add x2, x2, %x[row], lsl #2\n"
1386 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
1387 // Load 4 lhs_sums values.
1388 "ld1 {v11.4s}, [x2]\n"
1389 "ins v13.s[1], w5\n" // rhs_zero_point
1390 // Compute lhs_sums * rhs_zero_point.
1391 "mul v11.4s, v11.4s, v13.s[1]\n"
1392 // Subtract lhs_sums * rhs_zero_point, per
1393 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1394 "sub v16.4s, v16.4s, v11.4s\n"
1395
1396 // If the destination is int32, it means the user asks for the raw
1397 // accumulators, no need for us to downquantize the value.
1398 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1399 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1400
1401 "402:\n"
1402
1403 // At this point we have computed the final int32 values. Now we
1404 // start down-quantizing them to obtain the final 8bit values from them.
1405
1406 // As part of this down-quantization, our int32 values will be
1407 // multiplied by a multiplier that has a fixed-point component and an
1408 // exponent component.
1409
1410 //Load the exponent part of the multiplier.
1411 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1412 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1413 "add x5, x1, %x[row], lsl #2\n"
1414 "csel x1, x1, x5, eq\n"
1415
1416 "ld1 {v14.4s}, [x1]\n"
1417
1418 "smax v12.4s, v14.4s, v8.4s\n"
1419
1420 "sshl v16.4s, v16.4s, v12.4s\n"
1421
1422 "smin v12.4s, v14.4s, v8.4s\n"
1423
1424 // Apply the fixed-point part of the multiplier.
1425 "sqrdmulh v16.4s, v16.4s, v15.4s\n"
1426
1427 // We have some rounding division-by-power-of-two to do. This should
1428 // always use "round to nearest". We allow for some
1429 // freedom in how ties are broken, to strike a good compromise of
1430 // performance on given hardware vs. perfect agreement of results
1431 // across hardware.
1432 //
1433 // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
1434 // defined tie-breaks to help performance. On NEON, this means that we
1435 // can just use the NEON rounding instructions, such as srshl. They
1436 // happen to be breaking ties upward.
1437 //
1438 // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
1439 // break-ties-away-from zero, as described in Appendix B of
1440 // https://arxiv.org/pdf/1712.05877.pdf
1441 // When we wrote that, we thought that that would be better unbiased
1442 // than the NEON upwards tie-breaks, and we had observed some
1443 // improvement on some model. However, that is only more unbiased for
1444 // data centered at zero, which was likely the case in that model,
1445 // but is not always the case. If we wanted something more consistently
1446 // unbiased then we should try breaking ties toward-nearest-even.
1447 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
1448 // Fix up values to be right-shifted, so that the (round to nearest,
1449 // break ties upward) behavior of srshl applied to these fixed-up
1450 // values, produces the same result as the desired (round to nearest,
1451 // break ties away from zero) behavior on the original values.
1452 "and v8.16b, v16.16b, v12.16b\n"
1453 "sshr v8.4s, v8.4s, #31\n"
1454 "sqadd v16.4s, v16.4s, v8.4s\n"
1455 #endif
1456 // At this point we have reduced the problem of correctly implementing
1457 // rounding divide-by-power-of-two, to what the SRSHL instruction can
1458 // do.
1459 "srshl v16.4s, v16.4s, v12.4s\n"
1460
1461 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1462 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1463 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1464 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1465
1466 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1467
1468 // Cast-and-saturate from int32 to int16
1469 // After this instruction, all data is in lower half (64-bits) of v16
1470 "sqxtn v16.4h, v16.4s\n"
1471
1472 // At this point, v18 -- v31 aren't used anymore for the current block,
1473 // so we can start clearing these accumulators for the next block
1474 // (next iteration of the main loop).
1475 RUY_MAKE_ZERO(v18)
1476 RUY_MAKE_ZERO(v19)
1477
1478 // Add the destination zero point
1479 "dup v14.8h, v13.h[4]\n"
1480 "add v16.8h, v16.8h, v14.8h\n"
1481
1482 // Cast-and-saturate from int16 to uint8
1483 // Now all data is in the first 32-bits of v16
1484 "sqxtun v16.8b, v16.8h\n"
1485
1486 // Load the clamp_min, clamp_max bounds
1487 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1488 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1489 "dup v14.16b, w2\n" // clamp_min
1490 "dup v15.16b, w3\n" // clamp_max
1491
1492 // Apply the clamp_min bound
1493 "umax v16.16b, v16.16b, v14.16b\n"
1494 // Apply the clamp_max bound
1495 "umin v16.16b, v16.16b, v15.16b\n"
1496
1497 // Compute how much of the 4x1 block of destination 8bit values that
1498 // we have computed, fit in the destination matrix. Typically, all of
1499 // it fits, but when the destination matrix shape is not a multiple
1500 // of 4x1, there are some 4x1 blocks along the boundaries that do
1501 // not fit entirely.
1502 "sub w1, %w[dst_rows], %w[row]\n"
1503 "mov w3, #4\n"
1504 "cmp w1, #4\n"
1505 // Compute w1 = how many rows of the 4x1 block fit
1506 "csel w1, w1, w3, le\n"
1507
1508 // Test if w1==4, i.e. if all of the 4x1 block fits.
1509 "cmp w1, w3\n"
1510
1511 "mov x4, %[dst_ptr]\n"
1512 // Yes, all of the 4x1 block fits, go to fast path.
1513 "beq 30f\n"
1514 // Not all of the 4x1 block fits.
1515 // Store to dst_tmp_buf
1516 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
1517 // Slow loop copying from dst_tmp_buf to dst.
1518 "mov x3, %[dst_tmp_buf]\n"
1519 "mov w6, #0\n"
1520 "50:\n"
1521 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1522 "mov w5, #0\n"
1523 "51:\n"
1524 "ldrb w7, [x3, w5, uxtw]\n"
1525 "strb w7, [x4, w5, uxtw]\n"
1526 "add w5, w5, #1\n"
1527 "cmp w5, w1\n"
1528 "blt 51b\n"
1529 "b 31f\n"
1530 "30:\n"
1531 // Yes, all of the 4x1 block fits.
1532 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1533 "mov x3, x4\n"
1534 "st1 {v16.b}[0], [x3], #1\n"
1535 "st1 {v16.b}[1], [x3], #1\n"
1536 "st1 {v16.b}[2], [x3], #1\n"
1537 "st1 {v16.b}[3], [x3], #1\n"
1538 "31:\n"
1539
1540 "add %[dst_ptr], %[dst_ptr], #4\n"
1541
1542 RUY_MAKE_ZERO(v16)
1543 RUY_MAKE_ZERO(v17)
1544
1545 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1546
1547 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
1548
1549 // Cast-and-saturate from int32 to int16
1550 // After this, all values for output are in the lower half (64 bits) of v16.
1551 "sqxtn v16.4h, v16.4s\n"
1552
1553 // At this point, v18 -- v31 aren't used anymore for the current block,
1554 // so we can start clearing these accumulators for the next block
1555 // (next iteration of the main loop).
1556 RUY_MAKE_ZERO(v18)
1557 RUY_MAKE_ZERO(v19)
1558
1559 // Add the destination zero point
1560 "dup v14.8h, v13.h[4]\n"
1561 "add v16.8h, v16.8h, v14.8h\n"
1562
1563 // Cast-and-saturate from int16 to int8
1564 "sqxtn v16.8b, v16.8h\n"
1565     // At this point, we only need the 4 lowest 8-bit values in v16.
1566
1567 // Load the clamp_min, clamp_max bounds
1568 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1569 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1570 "dup v14.16b, w2\n" // clamp_min
1571 "dup v15.16b, w3\n" // clamp_max
1572
1573 // Apply the clamp_min bound
1574 "smax v16.16b, v16.16b, v14.16b\n"
1575 // Apply the clamp_max bound
1576 "smin v16.16b, v16.16b, v15.16b\n"
1577
1578     // Compute how much of the 4x1 block of destination 8bit values that
1579     // we have computed fits in the destination matrix. Typically, all of
1580     // it fits, but when the destination matrix shape is not a multiple
1581     // of 4x1, there are some 4x1 blocks along the boundaries that do
1582     // not fit entirely.
1583 "sub w1, %w[dst_rows], %w[row]\n"
1584 "sub w2, %w[dst_cols], %w[col]\n"
1585 "mov w3, #4\n"
1586 "cmp w1, #4\n"
1587 // Compute w1 = how many rows of the 4x1 block fit
1588 "csel w1, w1, w3, le\n"
1589 "cmp w2, #4\n"
1590
1591 // Test if w1==4, i.e. if all of the 4x1 block fits.
1592 "cmp w1, w3\n"
1593 "ccmp w2, w3, 0, eq\n"
1594 "mov x4, %[dst_ptr]\n"
1595 // Yes, all of the 4x1 block fits, go to fast path.
1596 "beq 30f\n"
1597     // Not all of the 4x1 block fits.
1598 // Store to dst_tmp_buf
1599 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
1600 // Slow loop copying from dst_tmp_buf to dst.
1601 "mov x3, %[dst_tmp_buf]\n"
1602 "mov w6, #0\n"
1603 "50:\n"
1604 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1605 "mov w5, #0\n"
1606 "51:\n"
1607 "ldrb w7, [x3, w5, uxtw]\n"
1608 "strb w7, [x4, w5, uxtw]\n"
1609 "add w5, w5, #1\n"
1610 "cmp w5, w1\n"
1611 "blt 51b\n"
1612 "b 31f\n"
1613 "30:\n"
1614     // Yes, all of the 4x1 block fits.
1615 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1616 "mov x3, x4\n"
1617 "st1 {v16.b}[0], [x3], #1\n"
1618 "st1 {v16.b}[1], [x3], #1\n"
1619 "st1 {v16.b}[2], [x3], #1\n"
1620 "st1 {v16.b}[3], [x3], #1\n"
1621 "31:\n"
1622
1623 "add %[dst_ptr], %[dst_ptr], #4\n"
1624
1625 RUY_MAKE_ZERO(v16)
1626 RUY_MAKE_ZERO(v17)
1627
1628 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1629
1630 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
1631
1632 // Add the destination zero point
1633 "dup v14.4h, v13.h[4]\n"
1634 "saddw v16.4s, v16.4s, v14.4h\n"
1635
1636 // Cast-and-saturate from int32 to int16
1637     // After this instruction, all data is in the lower half of v16.
1638 "sqxtn v16.4h, v16.4s\n"
1639
1640 // At this point, v18 -- v31 aren't used anymore for the current block,
1641 // so we can start clearing these accumulators for the next block
1642 // (next iteration of the main loop).
1643 RUY_MAKE_ZERO(v18)
1644 RUY_MAKE_ZERO(v19)
1645
1646 // Load the clamp_min, clamp_max bounds
1647 "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1648 "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1649 "dup v14.8h, w2\n" // clamp_min
1650 "dup v15.8h, w3\n" // clamp_max
1651
1652 // Apply the clamp_min bound
1653 "smax v16.8h, v16.8h, v14.8h\n"
1654 // Apply the clamp_max bound
1655 "smin v16.8h, v16.8h, v15.8h\n"
1656
1657     // Compute how much of the 4x1 block of destination 16bit values that
1658     // we have computed fits in the destination matrix. Typically, all of
1659     // it fits, but when the destination matrix shape is not a multiple
1660     // of 4x1, there are some 4x1 blocks along the boundaries that do
1661     // not fit entirely.
1662 "sub w1, %w[dst_rows], %w[row]\n"
1663 "sub w2, %w[dst_cols], %w[col]\n"
1664 "mov w3, #4\n"
1665 "cmp w1, #4\n"
1666     // Compute w1 = how many rows of the 4x1 block fit
1667 "csel w1, w1, w3, le\n"
1668 "cmp w2, #4\n"
1669
1670     // Test if w1==4, i.e. if all of the 4x1 block fits.
1671 "cmp w1, w3\n"
1672 "mov x4, %[dst_ptr]\n"
1673     // Yes, all of the 4x1 block fits, go to fast path.
1674 "beq 30f\n"
1675     // Not all of the 4x1 block fits.
1676 // Store to dst_tmp_buf
1677 "str q16, [%[dst_tmp_buf], #0]\n"
1678 // Slow loop copying from dst_tmp_buf to dst.
1679 "mov x3, %[dst_tmp_buf]\n"
1680 "mov w6, #0\n"
1681 "50:\n"
1682 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1683 "mov w5, #0\n"
1684 "51:\n"
1685 "ldrh w7, [x3, x5, lsl #1]\n"
1686 "strh w7, [x4, x5, lsl #1]\n"
1687 "add w5, w5, #1\n"
1688 "cmp w5, w1\n"
1689 "blt 51b\n"
1691 "b 31f\n"
1692 "30:\n"
1693     // Yes, all of the 4x1 block fits.
1694 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1695 "mov x3, x4\n"
1696 "st1 {v16.h}[0], [x3], #2\n"
1697 "st1 {v16.h}[1], [x3], #2\n"
1698 "st1 {v16.h}[2], [x3], #2\n"
1699 "st1 {v16.h}[3], [x3], #2\n"
1700 "31:\n"
1701
1702 "add %[dst_ptr], %[dst_ptr], #8\n"
1703
1704 RUY_MAKE_ZERO(v16)
1705 RUY_MAKE_ZERO(v17)
1706
1707 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1708
1709 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
1710
1711 // Since the store type is the same as the accum type, no need for
1712 // downcast. There's also no need for clamp by min/max.
1713
1714     // Compute how much of the 4x1 block of destination 32bit values that
1715     // we have computed fits in the destination matrix. Typically, all of
1716     // it fits, but when the destination matrix shape is not a multiple
1717     // of 4x1, there are some 4x1 blocks along the boundaries that do
1718     // not fit entirely.
1719 "sub w1, %w[dst_rows], %w[row]\n"
1720 "sub w2, %w[dst_cols], %w[col]\n"
1721 "mov w3, #4\n"
1722 "cmp w1, #4\n"
1723     // Compute w1 = how many rows of the 4x1 block fit
1724 "csel w1, w1, w3, le\n"
1725 "cmp w2, #4\n"
1726
1727     // Test if w1==4, i.e. if all of the 4x1 block fits.
1728 "cmp w1, w3\n"
1729 "ccmp w2, w3, 0, eq\n"
1730 "mov x4, %[dst_ptr]\n"
1731 // Yes, all of the 4x1 block fits, go to fast path.
1732 "beq 30f\n"
1733     // Not all of the 4x1 block fits.
1734 // Store to dst_tmp_buf
1735 "str q16, [%[dst_tmp_buf], #0]\n"
1736 // Slow loop copying from dst_tmp_buf to dst.
1737 "mov x3, %[dst_tmp_buf]\n"
1738 "mov w6, #0\n"
1739 "50:\n"
1740 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1741 "mov w5, #0\n"
1742 "51:\n"
1743 "ldr w7, [x3, x5, lsl #2]\n"
1744 "str w7, [x4, x5, lsl #2]\n"
1745 "add w5, w5, #1\n"
1746 "cmp w5, w1\n"
1747 "blt 51b\n"
1748 "b 31f\n"
1749 "30:\n"
1750     // Yes, all of the 4x1 block fits.
1751 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1752 "mov x3, x4\n"
1753 "st1 {v16.s}[0], [x3], #4\n"
1754 "st1 {v16.s}[1], [x3], #4\n"
1755 "st1 {v16.s}[2], [x3], #4\n"
1756 "st1 {v16.s}[3], [x3], #4\n"
1757 "31:\n"
1758
1759 "add %[dst_ptr], %[dst_ptr], #16\n"
1760
1761 RUY_MAKE_ZERO(v16)
1762 RUY_MAKE_ZERO(v17)
1763 RUY_MAKE_ZERO(v18)
1764 RUY_MAKE_ZERO(v19)
1765
1766 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1767
1768 // For the next block: perform the first few multiply-adds on the data
1769 // that we have already loaded.
1770 "smull v8.8h, v0.8b, v4.8b\n"
1771 "smull v9.8h, v1.8b, v4.8b\n"
1772 "smull v10.8h, v2.8b, v4.8b\n"
1773 "smull v11.8h, v3.8b, v4.8b\n"
1774 "smlal2 v8.8h, v0.16b, v4.16b\n"
1775 "smlal2 v9.8h, v1.16b, v4.16b\n"
1776 "smlal2 v10.8h, v2.16b, v4.16b\n"
1777 "smlal2 v11.8h, v3.16b, v4.16b\n"
1778
1779 // Reload some params --- we had used x5 -- x7 for a few other things
1780 // since the last time we had loaded them.
1781 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1782 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1783 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1784
1785 // Move to the next block of the destination matrix, for the next iter
1786 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
1787 // been updated earlier.
1788 // Have we reached the end row?
1789 "cmp %w[row], w7\n"
1790 "beq 20f\n" // yes, end row.
1791 // Not end row. Move to the next row.
1792 "add %w[row], %w[row], #4\n"
1793 "b 21f\n"
1794 "20:\n"
1795 // Was already at end row.
1796 "mov %w[row], w6\n" // Move back to first row.
1797 "add %w[col], %w[col], #4\n" // Move to the next column.
1798 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
1799 "mov %[dst_ptr], %[dst_col_ptr]\n"
1800 "21:\n"
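    // In C terms, the advance logic above amounts to (sketch; x11 holds
    // dst_stride in bytes):
    //   if (row != last_row) {
    //     row += 4;                        // next block down the same column
    //   } else {
    //     row = start_row;                 // back to the top...
    //     col += 4;                        // ...of the next columns
    //     dst_col_ptr += 4 * dst_stride;
    //     dst_ptr = dst_col_ptr;
    //   }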
1801
1802 // Main loop exit condition: have we hit the end column?
1803 "cmp %w[col], w8\n"
1804
1805 // w1 is the number of levels of depth that we have already loaded
1806 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1807 // above, this is currently 16.
1808 "mov w1, #16\n"
1809
1810 "ble 1b\n"
1811
1812 // clang-format on
1813
1814 : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
1815 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1816 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
1817 : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
1818 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
1819 [dst_type_id] "r"(params.dst_type_id)
1820 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
1821 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
1822 "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1823 }
1824
1825 // Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs.
1826 // Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
1827 // the original Cortex-A55, since these are 64-bit and do not support dotprod.
1828 //
1829 // While this kernel does not have a direct equivalent in gemmlowp, it was
1830 // developed based on insights that David Mansell at ARM shared with their
1831 // contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
1832 // comments. Specifically, see this comment about tuning for Cortex-A53:
1833 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
1834 void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) {
1835 profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
1836
1837 CheckOffsetsInKernelParams8bit(params);
1838
1839 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1840 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
1841 const std::int8_t* lhs_ptr = lhs_col_ptr;
1842 const std::int8_t* rhs_ptr = rhs_col_ptr;
1843 void* dst_col_ptr = params.dst_base_ptr;
1844 void* dst_ptr = dst_col_ptr;
1845 int row = params.start_row;
1846 int col = params.start_col;
1847
1848 // The asm kernel below has the following NEON register allocation:
1849 //
1850 // v16 -- v31 are int32 accumulators.
1851 // During accumulation, v0 -- v3 are used to load int8 data from LHS and
1852 // v4 -- v7 from RHS:
1853 //
1854 // int8 RHS 16x4 block
1855 // /-----------------------------------------\
1856 // |v4.b[0] ... v7.b[0] |
1857 // | ... ... |
1858 // |v4.b[15] ... v7.b[15] |
1859 // \-----------------------------------------/
1860 // int8 LHS 4x16 block
1861 // /---------------------\ /-----------------------------------------\
1862 // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
1863 // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
1864 // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
1865 // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
1866 // \---------------------/ \-----------------------------------------/
1867 // int32 accumulators 4x4 block
1868 asm volatile(
1869 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
1870
1871 // clang-format off
1872
1873 // Load some parameters into registers.
1874 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1875 RUY_MAKE_ZERO(v16)
1876 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1877 RUY_MAKE_ZERO(v17)
1878 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1879 RUY_MAKE_ZERO(v18)
1880 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1881 RUY_MAKE_ZERO(v19)
1882 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
1883 RUY_MAKE_ZERO(v20)
1884 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
1885 RUY_MAKE_ZERO(v21)
1886 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1887 RUY_MAKE_ZERO(v22)
1888 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
1889 RUY_MAKE_ZERO(v23)
1890
1891 // Load the first 64 bytes of LHS and RHS data.
1892 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
1893 RUY_MAKE_ZERO(v24)
1894 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
1895 RUY_MAKE_ZERO(v25)
1896 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
1897 RUY_MAKE_ZERO(v26)
1898 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
1899 RUY_MAKE_ZERO(v27)
1900 "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
1901 RUY_MAKE_ZERO(v28)
1902 "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
1903 RUY_MAKE_ZERO(v29)
1904 "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
1905 RUY_MAKE_ZERO(v30)
1906 "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
1907 RUY_MAKE_ZERO(v31)
1908
1909
1910 // w1 is the number of levels of depth that we have already loaded
1911 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1912 // above, this is currently 16.
1913 "mov w1, #16\n"
1914
1915 // Perform the first few multiply-adds on the data that we have already
1916 // loaded.
1917 "smull v8.8h, v0.8b, v4.8b\n"
1918 "smull v9.8h, v1.8b, v4.8b\n"
1919 "smull v10.8h, v2.8b, v4.8b\n"
1920 "smull v11.8h, v3.8b, v4.8b\n"
1921 "smull v12.8h, v0.8b, v5.8b\n"
1922 "smull v13.8h, v1.8b, v5.8b\n"
1923 "smull v14.8h, v2.8b, v5.8b\n"
1924 "smull v15.8h, v3.8b, v5.8b\n"
1925
1926 // Multiply-accumulate second-half, again into the same
1927 // 16bit local accumulator registers. This is where we
1928 // take advantage of having int8 instead of uint8 and therefore
1929 // being able to accumulate two products into int16.
1930 "smlal2 v8.8h, v0.16b, v4.16b\n"
1931 "smlal2 v9.8h, v1.16b, v4.16b\n"
1932 "smlal2 v10.8h, v2.16b, v4.16b\n"
1933 "smlal2 v11.8h, v3.16b, v4.16b\n"
1934 "smlal2 v12.8h, v0.16b, v5.16b\n"
1935 "smlal2 v13.8h, v1.16b, v5.16b\n"
1936 "smlal2 v14.8h, v2.16b, v5.16b\n"
1937 "smlal2 v15.8h, v3.16b, v5.16b\n"
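    // Why two products can share an int16 lane (a back-of-the-envelope
    // check): a product of two int8 values has magnitude at most
    // 128 * 128 = 16384, and exactly two such products are summed per lane
    // (one smull, one smlal2) before sadalp widens them into the int32
    // accumulators. With uint8 operands, a single product could already
    // reach 255 * 255 = 65025, so no such 16-bit staging would be possible.
    // (The extreme corner where both products are (-128) * (-128) sums to
    // 32768, one past INT16_MAX.)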
1938
1939
1940 // Main loop of the whole GEMM, over rows and columns of the
1941 // destination matrix.
1942 "1:\n"
1943
1944 // Reminder - w1 is how many levels of depth we have already loaded
1945 // data for, w12 is the total depth.
1946 "cmp w1, w12\n"
1947 "beq 79f\n"
1948
1949 "2:\n"
1950
1951 // Some multiplications and 16-bit accumulation were already done above,
1952 // so we start right away in the middle.
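    // A note on the load pattern used throughout this loop: 128-bit vector
    // loads are split into a 64-bit `ldr d` of the low half plus a plain
    // `ldr x` into a general-purpose register that is merged into the high
    // half later with `ins`. On Cortex-A53-class in-order cores, these
    // narrower loads dual-issue alongside NEON arithmetic much better than
    // `ld1`/`ldr q` would; the idea comes from the gemmlowp Cortex-A53
    // kernels referenced in the comment above this function.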
1953 "sadalp v16.4s, v8.8h\n"
1954 "ldr d4, [%[rhs_ptr], #0]\n"
1955 "smull v8.8h, v0.8b, v6.8b\n"
1956 "ldr x7, [%[rhs_ptr], #8]\n"
1957 "sadalp v17.4s, v9.8h\n"
1958 "ldr d5, [%[rhs_ptr], #16]\n"
1959 "smull v9.8h, v1.8b, v6.8b\n"
1960 "ldr x8, [%[rhs_ptr], #24]\n"
1961 "sadalp v18.4s, v10.8h\n"
1962 "smull v10.8h, v2.8b, v6.8b\n"
1963 "sadalp v19.4s, v11.8h\n"
1964 "add %[lhs_ptr], %[lhs_ptr], #64\n"
1965 "smull v11.8h, v3.8b, v6.8b\n"
1966 "add %[rhs_ptr], %[rhs_ptr], #64\n"
1967 "sadalp v20.4s, v12.8h\n"
1968 // Each iteration of this loop advances by 16 levels of depth.
1969 "add w1, w1, #16\n"
1970 "smull v12.8h, v0.8b, v7.8b\n"
1971 // Loop termination condition
1972 "cmp w1, w12\n"
1973 "sadalp v21.4s, v13.8h\n"
1974 "ldr x3, [%[lhs_ptr], #-56]\n"
1975 "smull v13.8h, v1.8b, v7.8b\n"
1976 "ldr x4, [%[lhs_ptr], #-40]\n"
1977 "sadalp v22.4s, v14.8h\n"
1978 "ldr x5, [%[lhs_ptr], #-24]\n"
1979 "smull v14.8h, v2.8b, v7.8b\n"
1980 "ldr x6, [%[lhs_ptr], #-8]\n"
1981 "sadalp v23.4s, v15.8h\n"
1982 "smull v15.8h, v3.8b, v7.8b\n"
1983
1984 // Multiply-accumulate second-half, again into the same
1985 // 16bit local accumulator registers. This is where we
1986 // take advantage of having int8 instead of uint8 and therefore
1987 // being able to accumulate two products into int16.
1988 "smlal2 v8.8h, v0.16b, v6.16b\n"
1989 "smlal2 v9.8h, v1.16b, v6.16b\n"
1990 "smlal2 v10.8h, v2.16b, v6.16b\n"
1991 "ldr x9, [%[rhs_ptr], #-24]\n"
1992 "smlal2 v11.8h, v3.16b, v6.16b\n"
1993 "ldr d6, [%[rhs_ptr], #-32]\n"
1994 "smlal2 v12.8h, v0.16b, v7.16b\n"
1995 "ldr d0, [%[lhs_ptr], #-64]\n"
1996 "smlal2 v13.8h, v1.16b, v7.16b\n"
1997 "ldr d1, [%[lhs_ptr], #-48]\n"
1998 "smlal2 v14.8h, v2.16b, v7.16b\n"
1999 "ins v4.d[1], x7\n"
2000 "smlal2 v15.8h, v3.16b, v7.16b\n"
2001 "ins v5.d[1], x8\n"
2002
2003 "ldr d2, [%[lhs_ptr], #-32]\n"
2004 "ins v0.d[1], x3\n"
2005 "sadalp v24.4s, v8.8h\n"
2006 "ldr d3, [%[lhs_ptr], #-16]\n"
2007 "ins v1.d[1], x4\n"
2008 "smull v8.8h, v0.8b, v4.8b\n"
2009 "ins v2.d[1], x5\n"
2010 "sadalp v25.4s, v9.8h\n"
2011 "ins v3.d[1], x6\n"
2012 "smull v9.8h, v1.8b, v4.8b\n"
2013 "ldr d7, [%[rhs_ptr], #-16]\n"
2014 "sadalp v26.4s, v10.8h\n"
2015 "ldr x10, [%[rhs_ptr], #-8]\n"
2016 "smull v10.8h, v2.8b, v4.8b\n"
2017 "sadalp v27.4s, v11.8h\n"
2018 "smull v11.8h, v3.8b, v4.8b\n"
2019 "sadalp v28.4s, v12.8h\n"
2020 "smull v12.8h, v0.8b, v5.8b\n"
2021 "sadalp v29.4s, v13.8h\n"
2022 "smull v13.8h, v1.8b, v5.8b\n"
2023 "sadalp v30.4s, v14.8h\n"
2024 "smull v14.8h, v2.8b, v5.8b\n"
2025 "sadalp v31.4s, v15.8h\n"
2026 "smull v15.8h, v3.8b, v5.8b\n"
2027
2028 // Multiply-accumulate second-half, again into the same
2029 // 16bit local accumulator registers. This is where we
2030 // take advantage of having int8 instead of uint8 and therefore
2031 // being able to accumulate two products into int16.
2032 "smlal2 v8.8h, v0.16b, v4.16b\n"
2033 "smlal2 v9.8h, v1.16b, v4.16b\n"
2034 "smlal2 v10.8h, v2.16b, v4.16b\n"
2035 "smlal2 v11.8h, v3.16b, v4.16b\n"
2036
2037 "smlal2 v12.8h, v0.16b, v5.16b\n"
2038 "smlal2 v13.8h, v1.16b, v5.16b\n"
2039 "ins v6.d[1], x9\n"
2040 "smlal2 v14.8h, v2.16b, v5.16b\n"
2041 "ins v7.d[1], x10\n"
2042 "smlal2 v15.8h, v3.16b, v5.16b\n"
2043
2044 "blt 2b\n"
2045
2046 "79:\n"
2047
2048 "sadalp v16.4s, v8.8h\n"
2049 "smull v8.8h, v0.8b, v6.8b\n"
2050 "sadalp v17.4s, v9.8h\n"
2051 "smull v9.8h, v1.8b, v6.8b\n"
2052 "sadalp v18.4s, v10.8h\n"
2053 "smull v10.8h, v2.8b, v6.8b\n"
2054 "sadalp v19.4s, v11.8h\n"
2055 "smull v11.8h, v3.8b, v6.8b\n"
2056 "sadalp v20.4s, v12.8h\n"
2057 "smull v12.8h, v0.8b, v7.8b\n"
2058 "sadalp v21.4s, v13.8h\n"
2059 "smull v13.8h, v1.8b, v7.8b\n"
2060 "sadalp v22.4s, v14.8h\n"
2061 "smull v14.8h, v2.8b, v7.8b\n"
2062 "sadalp v23.4s, v15.8h\n"
2063 "smull v15.8h, v3.8b, v7.8b\n"
2064
2065 // Multiply-accumulate second-half, again into the same
2066 // 16bit local accumulator registers. This is where we
2067 // take advantage of having int8 instead of uint8 and therefore
2068 // being able to accumulate two products into int16.
2069 "smlal2 v8.8h, v0.16b, v6.16b\n"
2070 "smlal2 v9.8h, v1.16b, v6.16b\n"
2071 "smlal2 v10.8h, v2.16b, v6.16b\n"
2072 "smlal2 v11.8h, v3.16b, v6.16b\n"
2073
2074 "smlal2 v12.8h, v0.16b, v7.16b\n"
2075 "smlal2 v13.8h, v1.16b, v7.16b\n"
2076 "smlal2 v14.8h, v2.16b, v7.16b\n"
2077 "smlal2 v15.8h, v3.16b, v7.16b\n"
2078
2079 "sadalp v24.4s, v8.8h\n"
2080 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2081 "sadalp v25.4s, v9.8h\n"
2082 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2083 "sadalp v26.4s, v10.8h\n"
2084 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2085 "sadalp v27.4s, v11.8h\n"
2086 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2087 "sadalp v28.4s, v12.8h\n"
2088 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
2089 "sadalp v29.4s, v13.8h\n"
2090 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
2091 "sadalp v30.4s, v14.8h\n"
2092 "sadalp v31.4s, v15.8h\n"
2093
2094 // End of accumulation. The registers v16 -- v31 contain the final
2095 // int32 accumulator values of the current 4x4 destination block.
2096 // We now have to compute the final 8-bit values from these int32
2097 // accumulators, and advance to the next 4x4 block. We intertwine
2098 // these two aspects whenever possible for optimal pipelining, both
2099 // at the data flow level (prefetch data for next block as early as
2100 // possible) and instruction pipelining level (some of the next-block
2101 // work can dual-issue with some of the final work on the current
2102 // block).
2103
2104 // Reduce 32bit accumulators horizontally.
2105 "addp v16.4s, v16.4s, v17.4s\n"
2106 "addp v18.4s, v18.4s, v19.4s\n"
2107 "addp v20.4s, v20.4s, v21.4s\n"
2108 "addp v22.4s, v22.4s, v23.4s\n"
2109 "addp v24.4s, v24.4s, v25.4s\n"
2110 "addp v26.4s, v26.4s, v27.4s\n"
2111 "addp v28.4s, v28.4s, v29.4s\n"
2112 "addp v30.4s, v30.4s, v31.4s\n"
2113
2114 // Reduce 32bit accumulators horizontally, second pass
2115     // (each pass adds pairwise; we need to add 4-wise).
2116 "addp v16.4s, v16.4s, v18.4s\n"
2117 "addp v17.4s, v20.4s, v22.4s\n"
2118 "addp v18.4s, v24.4s, v26.4s\n"
2119 "addp v19.4s, v28.4s, v30.4s\n"
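    // What the two addp passes compute (sketch): `addp vd, vn, vm` sums
    // adjacent pairs across the concatenation vn:vm, so after the first
    // pass e.g.
    //   v16 = { v16[0]+v16[1], v16[2]+v16[3], v17[0]+v17[1], v17[2]+v17[3] }
    // and after the second pass v16.s[0..3] hold the full 4-way sums of
    // the original v16..v19 -- one finished int32 value per accumulator,
    // i.e. the first column of the 4x4 block. v17, v18 and v19 likewise
    // end up holding the other three columns.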
2120
2121 // Logic to advance to the next block in preparation for the next
2122 // iteration of the main loop. For now, we only want to compute
2123 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
2124 // not yet ready to update the values of row and col, as we still need
2125 // the current values for the rest of the work on the current block.
2126
2127 "cmp %w[row], w7\n" // Have we finished the last row?
2128 "bge 4f\n" // If finished last row, go to 4
2129 // Not finished last row: then advance to next row.
2130 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
2131 "b 5f\n"
2132 "4:\n" // Finished last row...
2133 "mov %[lhs_col_ptr], x5\n" // Go back to first row
2134 // Now we need to advance to the next column. If we already
2135 // finished the last column, then in principle we are done, however
2136 // we can't just return here, as we need to allow the end work of the
2137 // current block to complete. The good news is that at this point it
2138 // doesn't matter what data we load for the next column, since
2139 // we will exit from the main loop below before actually storing
2140 // anything computed from that data.
2141 "cmp %w[col], w8\n" // Have we finished the last column?
2142 "bge 5f\n" // If yes, just carry on without updating the column pointer.
2143 // Not finished last column: then advance to next column.
2144 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
2145 "5:\n"
2146
2147 // Set the LHS and RHS data pointers to the start of the columns just
2148 // computed.
2149 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
2150 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
2151
2152 // Load some parameters needed for the end work on current block.
2153 RUY_MAKE_ZERO(v8)
2154 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2155 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
2156 "ins v13.h[4], w4\n" // dst_zero_point
2157 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
2158 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
2159 "dup v9.4s, w3\n" // create prod_zp_depth_vec
2160 "add x5, x4, %x[row], lsl #2\n"
2161 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
2162 "csel x4, x4, x5, eq\n"
2163
2164 "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
2165
2166 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
2167 "add x5, x1, %x[row], lsl #2\n"
2168
2169 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
2170 "csel x1, x1, x5, eq\n"
2171
2172 // Load 4 bias values.
2173 "ld1 {v14.4s}, [x1]\n"
2174
2175 // Now that we know what LHS and RHS data the next iteration of the
2176 // main loop will need to load, we start loading the first 32 bytes of
2177     // each of LHS and RHS into v0 -- v7, as we don't need their current
2178     // contents anymore in the rest of the work on the current block.
2179
2180     // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point);
2181     // see the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
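    // For reference, writing Z1 = lhs_zero_point, Z2 = rhs_zero_point and
    // N = depth, equation (7) expands each accumulator as
    //   sum_d (lhs_d - Z1) * (rhs_d - Z2)
    //     = sum_d lhs_d * rhs_d - Z2 * sum_d lhs_d - Z1 * sum_d rhs_d
    //       + N * Z1 * Z2.
    // The raw accumulators hold the first term; the N*Z1*Z2 term
    // (prod_zp_depth, in v9) is folded into the bias here, and the two
    // mixed terms are subtracted further below under the RHS_SUMS and
    // LHS_SUMS flags.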
2182 "add v14.4s, v14.4s, v9.4s\n"
2183 "ldr d0, [%[lhs_ptr], #0]\n"
2184
2185 // Perform the bias-addition (per the above, we have just folded into
2186 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
2187 "add v16.4s, v16.4s, v14.4s\n"
2188 "ldr d1, [%[lhs_ptr], #16]\n"
2189 "add v17.4s, v17.4s, v14.4s\n"
2190 "ldr d2, [%[lhs_ptr], #32]\n"
2191 "add v18.4s, v18.4s, v14.4s\n"
2192 "ldr d3, [%[lhs_ptr], #48]\n"
2193 "add v19.4s, v19.4s, v14.4s\n"
2194 "ldr d4, [%[rhs_ptr], #0]\n"
2195 "ldr d5, [%[rhs_ptr], #16]\n"
2196 "ldr d6, [%[rhs_ptr], #32]\n"
2197 "ldr d7, [%[rhs_ptr], #48]\n"
2198
2199 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
2200 "beq 401f\n"
2201 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
2202 "add x3, x3, %x[col], lsl #2\n"
2203 "ld1 {v14.4s}, [x3]\n"
2204 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
2205 "dup v10.4s, w5\n" // create lhs_zero_point_vec
2206 // Subtract rhs_sums * lhs_zero_point, per
2207 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
2208 "mls v16.4s, v10.4s, v14.s[0]\n"
2209 "mls v17.4s, v10.4s, v14.s[1]\n"
2210 "mls v18.4s, v10.4s, v14.s[2]\n"
2211 "mls v19.4s, v10.4s, v14.s[3]\n"
2212 "401:\n"
2213
2214 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
2215 "beq 402f\n"
2216 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
2217 "add x2, x2, %x[row], lsl #2\n"
2218 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
2219 // Load 4 lhs_sums values.
2220 "ld1 {v11.4s}, [x2]\n"
2221 "ins v13.s[1], w5\n" // rhs_zero_point
2222 // Compute lhs_sums * rhs_zero_point.
2223 "mul v11.4s, v11.4s, v13.s[1]\n"
2224 // Subtract lhs_sums * rhs_zero_point, per
2225 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
2226 "sub v16.4s, v16.4s, v11.4s\n"
2227 "sub v17.4s, v17.4s, v11.4s\n"
2228 "sub v18.4s, v18.4s, v11.4s\n"
2229 "sub v19.4s, v19.4s, v11.4s\n"
2230
2231     // If the destination is int32, the user is asking for the raw
2232     // accumulators, so there is no need for us to down-quantize the values.
2233 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
2234 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
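    // The overall store dispatch from here on, in C terms (sketch):
    //   if (dst_type_id == INT32) goto STORE_INT32;  // raw accumulators
    //   /* ...down-quantize: multiplier, rounding shift, zero point... */
    //   if (dst_type_id == INT16) goto STORE_INT16;
    //   if (dst_type_id == INT8)  goto STORE_INT8;
    //   goto STORE_UINT8;                            // default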
2235
2236 "402:\n"
2237
2238 // At this point we have computed the final int32 values. Now we
2239     // start down-quantizing them to obtain the final 8bit values.
2240
2241 // As part of this down-quantization, our int32 values will be
2242 // multiplied by a multiplier that has a fixed-point component and an
2243 // exponent component.
2244
2245
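    // A scalar sketch of the down-quantization performed below, assuming
    // the usual gemmlowp multiplier semantics where the effective scale is
    // fixedpoint * 2^exponent (the helper names follow their gemmlowp
    // counterparts and are not functions defined here):
    //   left_shift  = max(exponent, 0);                    // smax into v12
    //   acc <<= left_shift;                                // sshl
    //   right_shift = min(exponent, 0);                    // smin into v12
    //   acc = SaturatingRoundingDoublingHighMul(acc, fixedpoint); // sqrdmulh
    //   acc = RoundingDivideByPOT(acc, -right_shift);      // srshl (+ fixup)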
2246 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
2247 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
2248 "add x5, x1, %x[row], lsl #2\n"
2249 "csel x1, x1, x5, eq\n"
2250
2251 "ld1 {v14.4s}, [x1]\n"
2252
2253 "smax v12.4s, v14.4s, v8.4s\n"
2254 "ldr x1, [%[lhs_ptr], #8]\n"
2255
2256 "sshl v16.4s, v16.4s, v12.4s\n"
2257 "ldr x2, [%[lhs_ptr], #24]\n"
2258 "sshl v17.4s, v17.4s, v12.4s\n"
2259 "ldr x3, [%[lhs_ptr], #40]\n"
2260 "sshl v18.4s, v18.4s, v12.4s\n"
2261 "ldr x4, [%[lhs_ptr], #56]\n"
2262 "sshl v19.4s, v19.4s, v12.4s\n"
2263
2264 "smin v12.4s, v14.4s, v8.4s\n"
2265
2266 // Apply the fixed-point part of the multiplier.
2267 "ins v0.d[1], x1\n"
2268 "ldr x1, [%[rhs_ptr], #8]\n"
2269 "sqrdmulh v16.4s, v16.4s, v15.4s\n"
2270 "ins v1.d[1], x2\n"
2271 "ldr x2, [%[rhs_ptr], #24]\n"
2272 "sqrdmulh v17.4s, v17.4s, v15.4s\n"
2273 "ins v2.d[1], x3\n"
2274 "ldr x3, [%[rhs_ptr], #40]\n"
2275 "sqrdmulh v18.4s, v18.4s, v15.4s\n"
2276 "ins v3.d[1], x4\n"
2277 "ldr x4, [%[rhs_ptr], #56]\n"
2278 "sqrdmulh v19.4s, v19.4s, v15.4s\n"
2279
2280 // We have some rounding division-by-power-of-two to do. This should
2281 // always use "round to nearest". We allow for some
2282 // freedom in how ties are broken, to strike a good compromise of
2283 // performance on given hardware vs. perfect agreement of results
2284 // across hardware.
2285 //
2286     // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation-
2287 // defined tie-breaks to help performance. On NEON, this means that we
2288 // can just use the NEON rounding instructions, such as srshl. They
2289 // happen to be breaking ties upward.
2290 //
2291 // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
2292 // break-ties-away-from zero, as described in Appendix B of
2293 // https://arxiv.org/pdf/1712.05877.pdf
2294     // When we wrote that, we thought that it would be less biased than
2295     // the NEON upwards tie-breaks, and we had observed some improvement
2296     // on some model. However, that is only less biased for data centered
2297     // at zero, which was likely the case in that model, but is not always
2298     // the case. If we wanted something more consistently unbiased, then
2299     // we should try breaking ties to nearest-even.
2300 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
2301 // Fix up values to be right-shifted, so that the (round to nearest,
2302 // break ties upward) behavior of srshl applied to these fixed-up
2303 // values, produces the same result as the desired (round to nearest,
2304 // break ties away from zero) behavior on the original values.
2305 "and v8.16b, v16.16b, v12.16b\n"
2306 "and v9.16b, v17.16b, v12.16b\n"
2307 "and v14.16b, v18.16b, v12.16b\n"
2308 "and v15.16b, v19.16b, v12.16b\n"
2309 "sshr v8.4s, v8.4s, #31\n"
2310 "sshr v9.4s, v9.4s, #31\n"
2311 "sshr v14.4s, v14.4s, #31\n"
2312 "sshr v15.4s, v15.4s, #31\n"
2313 "sqadd v16.4s, v16.4s, v8.4s\n"
2314 "sqadd v17.4s, v17.4s, v9.4s\n"
2315 "sqadd v18.4s, v18.4s, v14.4s\n"
2316 "sqadd v19.4s, v19.4s, v15.4s\n"
2317 #endif
2318 // At this point we have reduced the problem of correctly implementing
2319 // rounding divide-by-power-of-two, to what the SRSHL instruction can
2320 // do.
2321 "srshl v16.4s, v16.4s, v12.4s\n"
2322 "srshl v17.4s, v17.4s, v12.4s\n"
2323 "srshl v18.4s, v18.4s, v12.4s\n"
2324 "srshl v19.4s, v19.4s, v12.4s\n"
2325
2326 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
2327 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
2328 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
2329 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
2330
2331 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
2332
2333 "ins v4.d[1], x1\n"
2334 "sqxtn v16.4h, v16.4s\n"
2335 "ins v5.d[1], x2\n"
2336 "sqxtn2 v16.8h, v17.4s\n"
2337 "ins v6.d[1], x3\n"
2338 "sqxtn v17.4h, v18.4s\n"
2339 "ins v7.d[1], x4\n"
2340 RUY_MAKE_ZERO(v18)
2341 "sqxtn2 v17.8h, v19.4s\n"
2342
2343 // At this point, v18 -- v31 aren't used anymore for the current block,
2344 // so we can start clearing these accumulators for the next block
2345 // (next iteration of the main loop).
2346 RUY_MAKE_ZERO(v19)
2347
2348 // Add the destination zero point
2349 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2350 "dup v14.8h, v13.h[4]\n"
2351 RUY_MAKE_ZERO(v20)
2352 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2353 "add v16.8h, v16.8h, v14.8h\n"
2354 RUY_MAKE_ZERO(v21)
2355 "add v17.8h, v17.8h, v14.8h\n"
2356 RUY_MAKE_ZERO(v22)
2357
2358 // Cast-and-saturate from int16 to uint8
2359 "sqxtun v16.8b, v16.8h\n"
2360 RUY_MAKE_ZERO(v23)
2361 "sqxtun2 v16.16b, v17.8h\n"
2362 RUY_MAKE_ZERO(v24)
2363
2364 // Load the clamp_min, clamp_max bounds
2365 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2366 RUY_MAKE_ZERO(v25)
2367 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2368 RUY_MAKE_ZERO(v26)
2369 "dup v14.16b, w2\n" // clamp_min
2370 RUY_MAKE_ZERO(v27)
2371 "dup v15.16b, w3\n" // clamp_max
2372 RUY_MAKE_ZERO(v28)
2373
2374 // Apply the clamp_min bound
2375 "umax v16.16b, v16.16b, v14.16b\n"
2376 RUY_MAKE_ZERO(v29)
2377 // Apply the clamp_max bound
2378 "umin v16.16b, v16.16b, v15.16b\n"
2379 RUY_MAKE_ZERO(v30)
2380
2381 // Compute how much of the 4x4 block of destination 8bit values that
2382     // we have computed fits in the destination matrix. Typically, all of
2383 // it fits, but when the destination matrix shape is not a multiple
2384 // of 4x4, there are some 4x4 blocks along the boundaries that do
2385 // not fit entirely.
2386 "sub w1, %w[dst_rows], %w[row]\n"
2387 RUY_MAKE_ZERO(v31)
2388 "sub w2, %w[dst_cols], %w[col]\n"
2389 "mov w3, #4\n"
2390 "cmp w1, #4\n"
2391 // Compute w1 = how many rows of the 4x4 block fit
2392 "csel w1, w1, w3, le\n"
2393 "cmp w2, #4\n"
2394 // Compute w2 = how many cols of the 4x4 block fit
2395 "csel w2, w2, w3, le\n"
2396
2397     // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2398 "cmp w1, w3\n"
2399 "ccmp w2, w3, 0, eq\n"
2400 "mov x4, %[dst_ptr]\n"
2401 // Yes, all of the 4x4 block fits, go to fast path.
2402 "beq 30f\n"
2403 // Not all of the 4x4 block fits.
2404 // Store to dst_tmp_buf
2405 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
2406 // Slow loop copying from dst_tmp_buf to dst.
2407 "mov x3, %[dst_tmp_buf]\n"
2408 "mov w6, #0\n"
2409 "50:\n"
2410 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2411 "mov w5, #0\n"
2412 "51:\n"
2413 "ldrb w7, [x3, w5, uxtw]\n"
2414 "strb w7, [x4, w5, uxtw]\n"
2415 "add w5, w5, #1\n"
2416 "cmp w5, w1\n"
2417 "blt 51b\n"
2418 "add w6, w6, #1\n"
2419 "add x3, x3, #4\n"
2420 "add x4, x4, x11\n"
2421 "cmp w6, w2\n"
2422 "blt 50b\n"
2423 "b 31f\n"
2424 "30:\n"
2425 // Yes, all of the 4x4 block fits.
2426 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2427 "mov x3, x4\n"
2428 "st1 {v16.b}[0], [x3], #1\n"
2429 "add x4, x4, x11\n"
2430 "st1 {v16.b}[1], [x3], #1\n"
2431 "st1 {v16.b}[2], [x3], #1\n"
2432 "st1 {v16.b}[3], [x3], #1\n"
2433 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2434 "mov x3, x4\n"
2435 "st1 {v16.b}[4], [x3], #1\n"
2436 "add x4, x4, x11\n"
2437 "st1 {v16.b}[5], [x3], #1\n"
2438 "st1 {v16.b}[6], [x3], #1\n"
2439 "st1 {v16.b}[7], [x3], #1\n"
2440 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2441 "mov x3, x4\n"
2442 "st1 {v16.b}[8], [x3], #1\n"
2443 "add x4, x4, x11\n"
2444 "st1 {v16.b}[9], [x3], #1\n"
2445 "st1 {v16.b}[10], [x3], #1\n"
2446 "st1 {v16.b}[11], [x3], #1\n"
2447 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2448 "mov x3, x4\n"
2449 "st1 {v16.b}[12], [x3], #1\n"
2450 "add x4, x4, x11\n"
2451 "st1 {v16.b}[13], [x3], #1\n"
2452 "st1 {v16.b}[14], [x3], #1\n"
2453 "st1 {v16.b}[15], [x3], #1\n"
2454 "31:\n"
2455
2456 "add %[dst_ptr], %[dst_ptr], #4\n"
2457
2458 RUY_MAKE_ZERO(v16)
2459 RUY_MAKE_ZERO(v17)
2460
2461 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2462
2463 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
2464
2465 "ins v4.d[1], x1\n"
2466 "sqxtn v16.4h, v16.4s\n"
2467 "ins v5.d[1], x2\n"
2468 "sqxtn2 v16.8h, v17.4s\n"
2469 "ins v6.d[1], x3\n"
2470 "sqxtn v17.4h, v18.4s\n"
2471 "ins v7.d[1], x4\n"
2472 RUY_MAKE_ZERO(v18)
2473 "sqxtn2 v17.8h, v19.4s\n"
2474
2475 // At this point, v18 -- v31 aren't used anymore for the current block,
2476 // so we can start clearing these accumulators for the next block
2477 // (next iteration of the main loop).
2478 RUY_MAKE_ZERO(v19)
2479
2480 // Add the destination zero point
2481 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2482 "dup v14.8h, v13.h[4]\n"
2483 RUY_MAKE_ZERO(v20)
2484 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2485 "add v16.8h, v16.8h, v14.8h\n"
2486 RUY_MAKE_ZERO(v21)
2487 "add v17.8h, v17.8h, v14.8h\n"
2488 RUY_MAKE_ZERO(v22)
2489
2490     // Cast-and-saturate from int16 to int8
2491 "sqxtn v16.8b, v16.8h\n"
2492 RUY_MAKE_ZERO(v23)
2493 "sqxtn2 v16.16b, v17.8h\n"
2494 RUY_MAKE_ZERO(v24)
2495
2496 // Load the clamp_min, clamp_max bounds
2497 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2498 RUY_MAKE_ZERO(v25)
2499 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2500 RUY_MAKE_ZERO(v26)
2501 "dup v14.16b, w2\n" // clamp_min
2502 RUY_MAKE_ZERO(v27)
2503 "dup v15.16b, w3\n" // clamp_max
2504 RUY_MAKE_ZERO(v28)
2505
2506 // Apply the clamp_min bound
2507 "smax v16.16b, v16.16b, v14.16b\n"
2508 RUY_MAKE_ZERO(v29)
2509 // Apply the clamp_max bound
2510 "smin v16.16b, v16.16b, v15.16b\n"
2511 RUY_MAKE_ZERO(v30)
2512
2513 // Compute how much of the 4x4 block of destination 8bit values that
2514     // we have computed fits in the destination matrix. Typically, all of
2515 // it fits, but when the destination matrix shape is not a multiple
2516 // of 4x4, there are some 4x4 blocks along the boundaries that do
2517 // not fit entirely.
2518 "sub w1, %w[dst_rows], %w[row]\n"
2519 RUY_MAKE_ZERO(v31)
2520 "sub w2, %w[dst_cols], %w[col]\n"
2521 "mov w3, #4\n"
2522 "cmp w1, #4\n"
2523 // Compute w1 = how many rows of the 4x4 block fit
2524 "csel w1, w1, w3, le\n"
2525 "cmp w2, #4\n"
2526 // Compute w2 = how many cols of the 4x4 block fit
2527 "csel w2, w2, w3, le\n"
2528
2529     // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2530 "cmp w1, w3\n"
2531 "ccmp w2, w3, 0, eq\n"
2532 "mov x4, %[dst_ptr]\n"
2533 // Yes, all of the 4x4 block fits, go to fast path.
2534 "beq 30f\n"
2535 // Not all of the 4x4 block fits.
2536 // Store to dst_tmp_buf
2537 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
2538 // Slow loop copying from dst_tmp_buf to dst.
2539 "mov x3, %[dst_tmp_buf]\n"
2540 "mov w6, #0\n"
2541 "50:\n"
2542 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2543 "mov w5, #0\n"
2544 "51:\n"
2545 "ldrb w7, [x3, w5, uxtw]\n"
2546 "strb w7, [x4, w5, uxtw]\n"
2547 "add w5, w5, #1\n"
2548 "cmp w5, w1\n"
2549 "blt 51b\n"
2550 "add w6, w6, #1\n"
2551 "add x3, x3, #4\n"
2552 "add x4, x4, x11\n"
2553 "cmp w6, w2\n"
2554 "blt 50b\n"
2555 "b 31f\n"
2556 "30:\n"
2557 // Yes, all of the 4x4 block fits.
2558 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2559 "mov x3, x4\n"
2560 "st1 {v16.b}[0], [x3], #1\n"
2561 "add x4, x4, x11\n"
2562 "st1 {v16.b}[1], [x3], #1\n"
2563 "st1 {v16.b}[2], [x3], #1\n"
2564 "st1 {v16.b}[3], [x3], #1\n"
2565 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2566 "mov x3, x4\n"
2567 "st1 {v16.b}[4], [x3], #1\n"
2568 "add x4, x4, x11\n"
2569 "st1 {v16.b}[5], [x3], #1\n"
2570 "st1 {v16.b}[6], [x3], #1\n"
2571 "st1 {v16.b}[7], [x3], #1\n"
2572 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2573 "mov x3, x4\n"
2574 "st1 {v16.b}[8], [x3], #1\n"
2575 "add x4, x4, x11\n"
2576 "st1 {v16.b}[9], [x3], #1\n"
2577 "st1 {v16.b}[10], [x3], #1\n"
2578 "st1 {v16.b}[11], [x3], #1\n"
2579 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2580 "mov x3, x4\n"
2581 "st1 {v16.b}[12], [x3], #1\n"
2582 "add x4, x4, x11\n"
2583 "st1 {v16.b}[13], [x3], #1\n"
2584 "st1 {v16.b}[14], [x3], #1\n"
2585 "st1 {v16.b}[15], [x3], #1\n"
2586 "31:\n"
2587
2588 "add %[dst_ptr], %[dst_ptr], #4\n"
2589
2590 RUY_MAKE_ZERO(v16)
2591 RUY_MAKE_ZERO(v17)
2592
2593 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2594
2595 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
2596
2597 // Add the destination zero point
2598 "dup v14.4h, v13.h[4]\n"
2599 "saddw v16.4s, v16.4s, v14.4h\n"
2600 "saddw v17.4s, v17.4s, v14.4h\n"
2601 "saddw v18.4s, v18.4s, v14.4h\n"
2602 "saddw v19.4s, v19.4s, v14.4h\n"
2603
2604 // Cast-and-saturate from int32 to int16
2605 "ins v4.d[1], x1\n"
2606 "sqxtn v16.4h, v16.4s\n"
2607 "ins v5.d[1], x2\n"
2608 "sqxtn2 v16.8h, v17.4s\n"
2609 "ins v6.d[1], x3\n"
2610 "sqxtn v17.4h, v18.4s\n"
2611 "ins v7.d[1], x4\n"
2612 RUY_MAKE_ZERO(v18)
2613 "sqxtn2 v17.8h, v19.4s\n"
2614
2615 // At this point, v18 -- v31 aren't used anymore for the current block,
2616 // so we can start clearing these accumulators for the next block
2617 // (next iteration of the main loop).
2618 RUY_MAKE_ZERO(v19)
2619
2620 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2621 RUY_MAKE_ZERO(v20)
2622 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2623 RUY_MAKE_ZERO(v21)
2624 RUY_MAKE_ZERO(v22)
2625
2626 RUY_MAKE_ZERO(v23)
2627 RUY_MAKE_ZERO(v24)
2628
2629 // Load the clamp_min, clamp_max bounds
2630 "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2631 RUY_MAKE_ZERO(v25)
2632 "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2633 RUY_MAKE_ZERO(v26)
2634 "dup v14.8h, w2\n" // clamp_min
2635 RUY_MAKE_ZERO(v27)
2636 "dup v15.8h, w3\n" // clamp_max
2637 RUY_MAKE_ZERO(v28)
2638
2639 // Apply the clamp_min bound
2640 "smax v16.8h, v16.8h, v14.8h\n"
2641 "smax v17.8h, v17.8h, v14.8h\n"
2642 RUY_MAKE_ZERO(v29)
2643 // Apply the clamp_max bound
2644 "smin v16.8h, v16.8h, v15.8h\n"
2645 "smin v17.8h, v17.8h, v15.8h\n"
2646 RUY_MAKE_ZERO(v30)
2647
2648     // Compute how much of the 4x4 block of destination 16bit values that
2649     // we have computed fits in the destination matrix. Typically, all of
2650 // it fits, but when the destination matrix shape is not a multiple
2651 // of 4x4, there are some 4x4 blocks along the boundaries that do
2652 // not fit entirely.
2653 "sub w1, %w[dst_rows], %w[row]\n"
2654 RUY_MAKE_ZERO(v31)
2655 "sub w2, %w[dst_cols], %w[col]\n"
2656 "mov w3, #4\n"
2657 "cmp w1, #4\n"
2658 // Compute w1 = how many rows of the 4x4 block fit
2659 "csel w1, w1, w3, le\n"
2660 "cmp w2, #4\n"
2661 // Compute w2 = how many cols of the 4x4 block fit
2662 "csel w2, w2, w3, le\n"
2663
2664 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2665 "cmp w1, w3\n"
2666 "ccmp w2, w3, 0, eq\n"
2667 "mov x4, %[dst_ptr]\n"
2668 // Yes, all of the 4x4 block fits, go to fast path.
2669 "beq 30f\n"
2670 // Not all of the 4x4 block fits.
2671 // Store to dst_tmp_buf
2672 "str q16, [%[dst_tmp_buf], #0]\n"
2673 "str q17, [%[dst_tmp_buf], #16]\n"
2674 // Slow loop copying from dst_tmp_buf to dst.
2675 "mov x3, %[dst_tmp_buf]\n"
2676 "mov w6, #0\n"
2677 "50:\n"
2678 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2679 "mov w5, #0\n"
2680 "51:\n"
2681 "ldrh w7, [x3, x5, lsl #1]\n"
2682 "strh w7, [x4, x5, lsl #1]\n"
2683 "add w5, w5, #1\n"
2684 "cmp w5, w1\n"
2685 "blt 51b\n"
2686 "add w6, w6, #1\n"
2687 "add x3, x3, #8\n"
2688 "add x4, x4, x11\n"
2689 "cmp w6, w2\n"
2690 "blt 50b\n"
2691 "b 31f\n"
2692 "30:\n"
2693 // Yes, all of the 4x4 block fits.
2694 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2695 "mov x3, x4\n"
2696 "st1 {v16.h}[0], [x3], #2\n"
2697 "add x4, x4, x11\n"
2698 "st1 {v16.h}[1], [x3], #2\n"
2699 "st1 {v16.h}[2], [x3], #2\n"
2700 "st1 {v16.h}[3], [x3], #2\n"
2701 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2702 "mov x3, x4\n"
2703 "st1 {v16.h}[4], [x3], #2\n"
2704 "add x4, x4, x11\n"
2705 "st1 {v16.h}[5], [x3], #2\n"
2706 "st1 {v16.h}[6], [x3], #2\n"
2707 "st1 {v16.h}[7], [x3], #2\n"
2708 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2709 "mov x3, x4\n"
2710 "st1 {v17.h}[0], [x3], #2\n"
2711 "add x4, x4, x11\n"
2712 "st1 {v17.h}[1], [x3], #2\n"
2713 "st1 {v17.h}[2], [x3], #2\n"
2714 "st1 {v17.h}[3], [x3], #2\n"
2715 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2716 "mov x3, x4\n"
2717 "st1 {v17.h}[4], [x3], #2\n"
2718 "add x4, x4, x11\n"
2719 "st1 {v17.h}[5], [x3], #2\n"
2720 "st1 {v17.h}[6], [x3], #2\n"
2721 "st1 {v17.h}[7], [x3], #2\n"
2722 "31:\n"
2723
2724 "add %[dst_ptr], %[dst_ptr], #8\n"
2725
2726 RUY_MAKE_ZERO(v16)
2727 RUY_MAKE_ZERO(v17)
2728
2729 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2730
2731 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
2732
2733 "ldr x1, [%[lhs_ptr], #8]\n"
2734 "ldr x2, [%[lhs_ptr], #24]\n"
2735 "ldr x3, [%[lhs_ptr], #40]\n"
2736 "ldr x4, [%[lhs_ptr], #56]\n"
2737
2738 "ins v0.d[1], x1\n"
2739 "ldr x1, [%[rhs_ptr], #8]\n"
2740 "ins v1.d[1], x2\n"
2741 "ldr x2, [%[rhs_ptr], #24]\n"
2742 "ins v2.d[1], x3\n"
2743 "ldr x3, [%[rhs_ptr], #40]\n"
2744 "ins v3.d[1], x4\n"
2745 "ldr x4, [%[rhs_ptr], #56]\n"
2746 "ins v4.d[1], x1\n"
2747 "ins v5.d[1], x2\n"
2748 "ins v6.d[1], x3\n"
2749 "ins v7.d[1], x4\n"
2750
2751 // Since the store type is the same as the accum type, no need for
2752 // downcast. There's also no need for clamp by min/max.
2753
2754 // At this point, v20 -- v31 aren't used anymore for the current block,
2755 // so we can start clearing these accumulators for the next block
2756 // (next iteration of the main loop).
2757
2758 RUY_MAKE_ZERO(v20)
2759 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2760 RUY_MAKE_ZERO(v21)
2761 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2762 RUY_MAKE_ZERO(v22)
2763
2764 RUY_MAKE_ZERO(v23)
2765 RUY_MAKE_ZERO(v24)
2766 RUY_MAKE_ZERO(v25)
2767 RUY_MAKE_ZERO(v26)
2768 RUY_MAKE_ZERO(v27)
2769 RUY_MAKE_ZERO(v28)
2770 RUY_MAKE_ZERO(v29)
2771 RUY_MAKE_ZERO(v30)
2772
2773     // Compute how much of the 4x4 block of destination 32bit values that
2774     // we have computed fits in the destination matrix. Typically, all of
2775 // it fits, but when the destination matrix shape is not a multiple
2776 // of 4x4, there are some 4x4 blocks along the boundaries that do
2777 // not fit entirely.
2778 "sub w1, %w[dst_rows], %w[row]\n"
2779 RUY_MAKE_ZERO(v31)
2780 "sub w2, %w[dst_cols], %w[col]\n"
2781 "mov w3, #4\n"
2782 "cmp w1, #4\n"
2783 // Compute w1 = how many rows of the 4x4 block fit
2784 "csel w1, w1, w3, le\n"
2785 "cmp w2, #4\n"
2786 // Compute w2 = how many cols of the 4x4 block fit
2787 "csel w2, w2, w3, le\n"
2788
2789     // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2790 "cmp w1, w3\n"
2791 "ccmp w2, w3, 0, eq\n"
2792 "mov x4, %[dst_ptr]\n"
2793 // Yes, all of the 4x4 block fits, go to fast path.
2794 "beq 30f\n"
2795 // Not all of the 4x4 block fits.
2796 // Store to dst_tmp_buf
2797 "str q16, [%[dst_tmp_buf], #0]\n"
2798 "str q17, [%[dst_tmp_buf], #16]\n"
2799 "str q18, [%[dst_tmp_buf], #32]\n"
2800 "str q19, [%[dst_tmp_buf], #48]\n"
2801 // Slow loop copying from dst_tmp_buf to dst.
2802 "mov x3, %[dst_tmp_buf]\n"
2803 "mov w6, #0\n"
2804 "50:\n"
2805 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2806 "mov w5, #0\n"
2807 "51:\n"
2808 "ldr w7, [x3, x5, lsl #2]\n"
2809 "str w7, [x4, x5, lsl #2]\n"
2810 "add w5, w5, #1\n"
2811 "cmp w5, w1\n"
2812 "blt 51b\n"
2813 "add w6, w6, #1\n"
2814 "add x3, x3, #16\n"
2815 "add x4, x4, x11\n"
2816 "cmp w6, w2\n"
2817 "blt 50b\n"
2818 "b 31f\n"
2819 "30:\n"
2820 // Yes, all of the 4x4 block fits.
2821 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2822 "mov x3, x4\n"
2823 "st1 {v16.s}[0], [x3], #4\n"
2824 "add x4, x4, x11\n"
2825 "st1 {v16.s}[1], [x3], #4\n"
2826 "st1 {v16.s}[2], [x3], #4\n"
2827 "st1 {v16.s}[3], [x3], #4\n"
2828 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2829 "mov x3, x4\n"
2830 "st1 {v17.s}[0], [x3], #4\n"
2831 "add x4, x4, x11\n"
2832 "st1 {v17.s}[1], [x3], #4\n"
2833 "st1 {v17.s}[2], [x3], #4\n"
2834 "st1 {v17.s}[3], [x3], #4\n"
2835 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2836 "mov x3, x4\n"
2837 "st1 {v18.s}[0], [x3], #4\n"
2838 "add x4, x4, x11\n"
2839 "st1 {v18.s}[1], [x3], #4\n"
2840 "st1 {v18.s}[2], [x3], #4\n"
2841 "st1 {v18.s}[3], [x3], #4\n"
2842 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2843 "mov x3, x4\n"
2844 "st1 {v19.s}[0], [x3], #4\n"
2845 "add x4, x4, x11\n"
2846 "st1 {v19.s}[1], [x3], #4\n"
2847 "st1 {v19.s}[2], [x3], #4\n"
2848 "st1 {v19.s}[3], [x3], #4\n"
2849 "31:\n"
2850
2851 "add %[dst_ptr], %[dst_ptr], #16\n"
2852
2853 RUY_MAKE_ZERO(v16)
2854 RUY_MAKE_ZERO(v17)
2855 RUY_MAKE_ZERO(v18)
2856 RUY_MAKE_ZERO(v19)
2857
2858 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
2859
2860 // For the next block: perform the first few multiply-adds on the data
2861 // that we have already loaded.
2862 "smull v8.8h, v0.8b, v4.8b\n"
2863 "smull v9.8h, v1.8b, v4.8b\n"
2864 "smull v10.8h, v2.8b, v4.8b\n"
2865 // Reload some params --- we had used x5 -- x7 for a few other things
2866 // since the last time we had loaded them.
2867 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2868 "smull v11.8h, v3.8b, v4.8b\n"
2869 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2870 "smull v12.8h, v0.8b, v5.8b\n"
2871 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2872 "smull v13.8h, v1.8b, v5.8b\n"
2873 "smull v14.8h, v2.8b, v5.8b\n"
2874 "smull v15.8h, v3.8b, v5.8b\n"
2875 // Move to the next block of the destination matrix, for the next iter
2876 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
2877 // been updated earlier.
2878 // Have we reached the end row?
2879 "cmp %w[row], w7\n"
2880 "smlal2 v8.8h, v0.16b, v4.16b\n"
2881 "smlal2 v9.8h, v1.16b, v4.16b\n"
2882 "smlal2 v10.8h, v2.16b, v4.16b\n"
2883 "smlal2 v11.8h, v3.16b, v4.16b\n"
2884 "smlal2 v12.8h, v0.16b, v5.16b\n"
2885 "smlal2 v13.8h, v1.16b, v5.16b\n"
2886 "smlal2 v14.8h, v2.16b, v5.16b\n"
2887 "smlal2 v15.8h, v3.16b, v5.16b\n"
2888
2889
2890 "beq 20f\n" // yes, end row.
2891 // Not end row. Move to the next row.
2892 "add %w[row], %w[row], #4\n"
2893 "b 21f\n"
2894 "20:\n"
2895 // Was already at end row.
2896 "mov %w[row], w6\n" // Move back to first row.
2897 "add %w[col], %w[col], #4\n" // Move to the next column.
2898 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
2899 "mov %[dst_ptr], %[dst_col_ptr]\n"
2900 "21:\n"
2901
2902 // Main loop exit condition: have we hit the end column?
2903 "cmp %w[col], w8\n"
2904
2905 // w1 is the number of levels of depth that we have already loaded
2906 // LHS and RHS data for. Corresponding to the initial ld1 instructions
2907     // above, this is currently 16.
2908 "mov w1, #16\n"
2909
2910 "ble 1b\n"
2911
2912 // clang-format on
2913
2914 : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
2915 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2916 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
2917 : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
2918 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
2919 [dst_type_id] "r"(params.dst_type_id)
2920 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
2921 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
2922 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
2923 "v26", "v27", "v28", "v29", "v30", "v31");
2924 }
2925
2926 // Kernel taking advantage of the optional dotprod instruction.
2927 // This is very similar to (and directly inspired by) this gemmlowp kernel
2928 // which was contributed by David Mansell at ARM:
2929 // NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct
2930 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391
2931 //
2932 // Besides the ruy-ification, the main difference here is that we use an 8x8
2933 // instead of 12x8 width, so as to stick to power-of-two widths. This slightly
2934 // narrower kernel layout is still wide enough to achieve high performance,
2935 // although we haven't actually performed a real comparison to know exactly
2936 // how this compares to ARM's aforementioned kernel.
2937 //
2938 // Relevant target CPUs for this kernel include the ARM Cortex-A76,
2939 // which is 64-bit, out-of-order and has dotprod support.
2940 void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) {
2941 profiler::ScopeLabel label(
2942 "Kernel (kNeonDotprod, optimized for out-of-order cores)");
2943
2944 CheckOffsetsInKernelParams8bit(params);
2945
2946 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
2947 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
2948 const std::int8_t* lhs_ptr = lhs_col_ptr;
2949 const std::int8_t* rhs_ptr = rhs_col_ptr;
2950 void* dst_col_ptr = params.dst_base_ptr;
2951 void* dst_ptr = dst_col_ptr;
2952 int row = params.start_row;
2953 int col = params.start_col;
2954
2955 // The asm kernel below has the following NEON register allocation:
2956 //
2957 // v16 -- v31 are int32 accumulators.
2958 // During accumulation, v0 -- v15 are used to load int8 data from LHS and
2959 // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and
2960 // v3 are used to load a 4x8 block of RHS, like this:
2961 //
2962 // int8 RHS 4x8 block
2963 // /-----------------------------------------\
2964 // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
2965 // | ... ... |
2966 // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
2967 // \-----------------------------------------/
2968 // int8 LHS 8x4 block
2969 // /---------------------\ /-----------------------------------------\
2970 // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
2971 // | ... ... | | ... ... |
2972 // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
2973 // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
2974 // | ... ... | | ... ... |
2975 // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
2976 // \---------------------/ \-----------------------------------------/
2977 // int32 accumulators 8x8 block
2978 //
2979 // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
2980 // is repeated 4 times, using 4x more registers for LHS and RHS; that
2981 // is where, instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
2982 //
2983 // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
2984 // unused, and v8 -- v15 are used for loading parameters used for the
2985 // post-accumulation part of the kernel.
2986 asm volatile(
2987 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
2988
2989 // clang-format off
2990
2991 // Load some parameters into registers.
2992 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2993 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2994 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2995 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2996 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
2997 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
2998 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2999 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
3000
3001 // Load the first 32 bytes of LHS and RHS data.
3002 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
3003 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
3004 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
3005 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
3006
3007 // Clear accumulators.
3008 RUY_MAKE_ZERO(v16)
3009 RUY_MAKE_ZERO(v17)
3010 RUY_MAKE_ZERO(v18)
3011 RUY_MAKE_ZERO(v19)
3012 RUY_MAKE_ZERO(v20)
3013 RUY_MAKE_ZERO(v21)
3014 RUY_MAKE_ZERO(v22)
3015 RUY_MAKE_ZERO(v23)
3016 RUY_MAKE_ZERO(v24)
3017 RUY_MAKE_ZERO(v25)
3018 RUY_MAKE_ZERO(v26)
3019 RUY_MAKE_ZERO(v27)
3020 RUY_MAKE_ZERO(v28)
3021 RUY_MAKE_ZERO(v29)
3022 RUY_MAKE_ZERO(v30)
3023 RUY_MAKE_ZERO(v31)
3024
3025 // w1 is the number of levels of depth that we have already loaded
3026 // LHS and RHS data for. Corresponding to the initial ld1 instructions
3027 // above, this is currently 4.
3028 "mov w1, #4\n"
3029
3030 // Perform the first few multiply-adds on the data that we have already
3031 // loaded.
3032 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3033 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3034 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3035 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
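    // A note on the .word directives used here: they are the raw encodings
    // of the sdot instructions given in the trailing comments, spelled out
    // so that the file assembles even with toolchains that do not know the
    // dotprod extension. Semantically, `sdot v16.4s, v0.16b, v2.4b[0]`
    // computes, for each 32-bit lane i of v16,
    //   v16.s[i] += v0.b[4*i+0]*v2.b[0] + v0.b[4*i+1]*v2.b[1]
    //             + v0.b[4*i+2]*v2.b[2] + v0.b[4*i+3]*v2.b[3]
    // i.e. a 4-deep int8 dot product accumulated into int32.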
3036
3037 // Main loop of the whole GEMM, over rows and columns of the
3038 // destination matrix.
3039 "1:\n"
3040
3041 // Optional, maximally-streaming, partial-unrolling (4x unrolled)
3042 // optimization of the kernel inner loop (over depth). For more
3043 // comments, see the non-unrolled loop below after the #endif.
3044 #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
3045 "cmp w12, #32\n"
3046 "blt 78f\n"
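        // The guard above only takes this 4x-unrolled path when depth >= 32.
        // The loads below (v4 -- v15) extend the 4 levels of depth already
        // held in v0 -- v3 to 16 levels in flight, and each unrolled
        // iteration then consumes and reloads 16 levels at a time;
        // `and w3, w12, #-16` rounds the depth down to a multiple of 16 to
        // form the unrolled loop's bound.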
3047
3048 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
3049 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
3050 "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
3051 "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
3052 "ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
3053 "ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
3054 "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
3055 "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
3056 "ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
3057 "ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
3058 "ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
3059 "ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
3060 "mov w1, #16\n"
3061
3062 "and w3, w12, #-16\n"
3063 "81:\n"
3064 "add w1, w1, #16\n"
3065
3066 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
3067 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
3068 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
3069 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
3070 "ldr q0, [%[lhs_ptr], #0]\n"
3071 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
3072 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
3073 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
3074 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
3075 "ldr q2, [%[rhs_ptr], #0]\n"
3076 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
3077 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
3078 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
3079 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
3080 "ldr q1, [%[lhs_ptr], #16]\n"
3081
3082 ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
3083 ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
3084 "ldr q3, [%[rhs_ptr], #16]\n"
3085 ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
3086 ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
3087 ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
3088 ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
3089 ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
3090 ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
3091 ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
3092 ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
3093 ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
3094 ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
3095 "ldr q5, [%[lhs_ptr], #48]\n"
3096 ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
3097 ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
3098 "ldr q7, [%[rhs_ptr], #48]\n"
3099 ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
3100 ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
3101 "ldr q4, [%[lhs_ptr], #32]\n"
3102
3103 ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
3104 ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
3105 "ldr q6, [%[rhs_ptr], #32]\n"
3106 ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
3107 ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
3108 ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
3109 ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
3110 ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
3111 ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
3112 ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
3113 ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
3114 ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
3115 ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
3116 "ldr q9, [%[lhs_ptr], #80]\n"
3117 ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
3118 ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
3119 "ldr q11, [%[rhs_ptr], #80]\n"
3120 ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
3121 ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
3122 "ldr q8, [%[lhs_ptr], #64]\n"
3123
3124 ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
3125 ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
3126 "ldr q10, [%[rhs_ptr], #64]\n"
3127 ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
3128 ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
3129 "add %[lhs_ptr], %[lhs_ptr], #128\n"
3130 ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
3131 ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
3132 "add %[rhs_ptr], %[rhs_ptr], #128\n"
3133 ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
3134 ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
3135 ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
3136 ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
3137 "cmp w1, w3\n"
3138 ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
3139 ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
3140 "ldr q13, [%[lhs_ptr], #-16]\n"
3141 ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
3142 ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
3143 "ldr q15, [%[rhs_ptr], #-16]\n"
3144 ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
3145 ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
3146 "ldr q12, [%[lhs_ptr], #-32]\n"
3147
3148 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3149 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3150 "ldr q14, [%[rhs_ptr], #-32]\n"
3151 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3152 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3153
3154 "blt 81b\n"
3155
3156 ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
3157 ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
3158 ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
3159 ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
3160 ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
3161 ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
3162 ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
3163 ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
3164 ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
3165 ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
3166 ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
3167 ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
3168 ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
3169 ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
3170 ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
3171 ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
3172
3173 ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
3174 ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
3175 ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
3176 ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
3177 ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
3178 ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
3179 ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
3180 ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
3181 ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
3182 ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
3183 ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
3184 ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
3185 ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
3186 ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
3187 ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
3188 ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
3189
3190 ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
3191 ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
3192 ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
3193 ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
3194 ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
3195 ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
3196 ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
3197 ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
3198 ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
3199 ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
3200 ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
3201 ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
3202 ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
3203 ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
3204 ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
3205 ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
3206
3207 "78:\n"
3208
3209 #endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
3210
3211 // Ordinary kernel inner loop (over depth), the simpler loop that the
3212 // above was an equivalent 4x-partially-unrolled version of.
3213
3214 // Reminder - w1 is how many levels of depth we have already loaded
3215 // data for, w12 is the total depth.
3216 "cmp w1, w12\n"
3217 "beq 79f\n"
3218
3219 "2:\n"
3220
3221 // Because of the data that we have already loaded, we can start the
3222 // loop body right away with some multiply-adds.
3223 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
3224 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
3225 // Each iteration of this loop advances by 4 levels of depth.
3226 "add w1, w1, #4\n"
3227 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
3228 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
3229 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
3230 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
3231 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
3232 // Loop termination condition.
3233 "cmp w1, w12\n"
3234 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
3235 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
3236 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
3237 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
3238 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
3239 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
3240 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
3241 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
3242 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3243 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3244 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3245 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3246 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
3247
3248 "blt 2b\n"
3249
3250 "79:\n"
3251 // End of the inner loop on depth. Now perform the remaining
3252 // multiply-adds of the last 4 levels of depth, for which the LHS
3253 // and RHS data is already loaded.
3254
3255 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
3256 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
3257 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
3258 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
3259 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
3260 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
3261 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
3262 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
3263 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
3264 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
3265 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
3266 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
3267
3268 // End of accumulation. The registers v16 -- v31 contain the final
3269 // int32 accumulator values of the current 8x8 destination block.
3270 // We now have to compute the final 8-bit values from these int32
3271 // accumulators, and advance to the next 8x8 block. We intertwine
3272 // these two aspects whenever possible for optimal pipelining, both
3273 // at the data flow level (prefetch data for next block as early as
3274 // possible) and instruction pipelining level (some of the next-block
3275 // work can dual-issue with some of the final work on the current
3276 // block).
3277
3278 // Logic to advance to the next block in preparation for the next
3279 // iteration of the main loop. For now, we only want to compute
3280 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
3281 // not yet ready to update the values of row and col, as we still need
3282 // the current values for the rest of the work on the current block.
3283
3284 "cmp %w[row], w7\n" // Have we finished the last row?
3285 "bge 4f\n" // If finished last row, go to 4
3286 // Not finished last row: then advance to next row.
3287 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
3288 "b 5f\n"
3289 "4:\n" // Finished last row...
3290 "mov %[lhs_col_ptr], x5\n" // Go back to first row
3291 // Now we need to advance to the next column. If we already
3292 // finished the last column, then in principle we are done, however
3293 // we can't just return here, as we need to allow the end work of the
3294 // current block to complete. The good news is that at this point it
3295 // doesn't matter what data we load for the next column, since
3296 // we will exit from the main loop below before actually storing
3297 // anything computed from that data.
3298 "cmp %w[col], w8\n" // Have we finished the last column?
3299 "bge 5f\n" // If yes, just carry on without updating the column pointer.
3300 // Not finished last column: then advance to next column.
3301 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
3302 "5:\n"
3303
3304 // Set the LHS and RHS data pointers to the start of the columns just
3305 // computed.
3306 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
3307 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
3308
3309 // Load some parameters needed for the end work on current block.
3310 RUY_MAKE_ZERO(v8)
3311 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
3312 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
3313 "ins v13.h[4], w4\n" // dst_zero_point
3314 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
3315 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
3316 "dup v9.4s, w3\n" // create prod_zp_depth_vec
3317 "add x5, x4, %x[row], lsl #2\n"
3318 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
3319 "csel x4, x4, x5, eq\n"
3320
3321 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
3322 "add x5, x1, %x[row], lsl #2\n"
3323
3324 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
3325 "csel x1, x1, x5, eq\n"
3326
3327 // Load 8 bias values.
3328 "ld1 {v14.4s}, [x1], #16\n"
3329 "ld1 {v15.4s}, [x1]\n"
3330
3331 // Now that we know what LHS and RHS data the next iteration of the
3332 // main loop will need to load, we start loading the first 32 bytes of
3333 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
3334 // in the rest of the work on the current block.
3335 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
3336 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
3337 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
3338 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
3339
3340         // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
3341 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
3342 "add v14.4s, v14.4s, v9.4s\n"
3343 "add v15.4s, v15.4s, v9.4s\n"
3344
3345 // Perform the bias-addition (per the above, we have just folded into
3346 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
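        // In scalar terms, for each element of the 8x8 block this computes
        //   acc[r][c] += bias[r] + depth * lhs_zero_point * rhs_zero_point
        // with v14/v15 holding the adjusted bias for rows 0-3 / 4-7.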
3347 "add v16.4s, v16.4s, v14.4s\n"
3348 "add v17.4s, v17.4s, v15.4s\n"
3349 "add v18.4s, v18.4s, v14.4s\n"
3350 "add v19.4s, v19.4s, v15.4s\n"
3351 "add v20.4s, v20.4s, v14.4s\n"
3352 "add v21.4s, v21.4s, v15.4s\n"
3353 "add v22.4s, v22.4s, v14.4s\n"
3354 "add v23.4s, v23.4s, v15.4s\n"
3355 "add v24.4s, v24.4s, v14.4s\n"
3356 "add v25.4s, v25.4s, v15.4s\n"
3357 "add v26.4s, v26.4s, v14.4s\n"
3358 "add v27.4s, v27.4s, v15.4s\n"
3359 "add v28.4s, v28.4s, v14.4s\n"
3360 "add v29.4s, v29.4s, v15.4s\n"
3361 "add v30.4s, v30.4s, v14.4s\n"
3362 "add v31.4s, v31.4s, v15.4s\n"
3363
3364 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
3365 "beq 401f\n"
3366 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
3367 "add x3, x3, %x[col], lsl #2\n"
3368 "ld1 {v14.4s}, [x3], #16\n"
3369 "ld1 {v15.4s}, [x3]\n"
3370 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
3371 "dup v10.4s, w5\n" // create lhs_zero_point_vec
3372 // Subtract rhs_sums * lhs_zero_point, per
3373 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
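        // In scalar terms: acc[r][c] -= lhs_zero_point * rhs_sums[c].
        // v14.s[0..3] / v15.s[0..3] hold rhs_sums for columns 0-3 / 4-7;
        // each mls below updates one 4-row half of one destination column.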
3374 "mls v16.4s, v10.4s, v14.s[0]\n"
3375 "mls v17.4s, v10.4s, v14.s[0]\n"
3376 "mls v18.4s, v10.4s, v14.s[1]\n"
3377 "mls v19.4s, v10.4s, v14.s[1]\n"
3378 "mls v20.4s, v10.4s, v14.s[2]\n"
3379 "mls v21.4s, v10.4s, v14.s[2]\n"
3380 "mls v22.4s, v10.4s, v14.s[3]\n"
3381 "mls v23.4s, v10.4s, v14.s[3]\n"
3382 "mls v24.4s, v10.4s, v15.s[0]\n"
3383 "mls v25.4s, v10.4s, v15.s[0]\n"
3384 "mls v26.4s, v10.4s, v15.s[1]\n"
3385 "mls v27.4s, v10.4s, v15.s[1]\n"
3386 "mls v28.4s, v10.4s, v15.s[2]\n"
3387 "mls v29.4s, v10.4s, v15.s[2]\n"
3388 "mls v30.4s, v10.4s, v15.s[3]\n"
3389 "mls v31.4s, v10.4s, v15.s[3]\n"
3390 "401:\n"
3391
3392 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
3393 "beq 402f\n"
3394 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
3395 "add x2, x2, %x[row], lsl #2\n"
3396 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
3397 // Load 4 lhs_sums values.
3398 "ld1 {v11.4s}, [x2], #16\n"
3399 "ld1 {v12.4s}, [x2]\n"
3400 "ins v13.s[1], w5\n" // rhs_zero_point
3401 // Compute lhs_sums * rhs_zero_point.
3402 "mul v11.4s, v11.4s, v13.s[1]\n"
3403 "mul v12.4s, v12.4s, v13.s[1]\n"
3404 // Subtract lhs_sums * rhs_zero_point, per
3405 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
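        // In scalar terms: acc[r][c] -= rhs_zero_point * lhs_sums[r].
        // v11/v12 hold lhs_sums * rhs_zero_point for rows 0-3 / 4-7, so the
        // same pair is subtracted from every column.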
3406 "sub v16.4s, v16.4s, v11.4s\n"
3407 "sub v17.4s, v17.4s, v12.4s\n"
3408 "sub v18.4s, v18.4s, v11.4s\n"
3409 "sub v19.4s, v19.4s, v12.4s\n"
3410 "sub v20.4s, v20.4s, v11.4s\n"
3411 "sub v21.4s, v21.4s, v12.4s\n"
3412 "sub v22.4s, v22.4s, v11.4s\n"
3413 "sub v23.4s, v23.4s, v12.4s\n"
3414 "sub v24.4s, v24.4s, v11.4s\n"
3415 "sub v25.4s, v25.4s, v12.4s\n"
3416 "sub v26.4s, v26.4s, v11.4s\n"
3417 "sub v27.4s, v27.4s, v12.4s\n"
3418 "sub v28.4s, v28.4s, v11.4s\n"
3419 "sub v29.4s, v29.4s, v12.4s\n"
3420 "sub v30.4s, v30.4s, v11.4s\n"
3421 "sub v31.4s, v31.4s, v12.4s\n"
3422
3423 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
3424 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
3425
3426 "402:\n"
3427
3428 // At this point we have computed the final int32 values. Now we
3429 // start down-quantizing them to obtain the final 8bit values from them.
3430
3431 // As part of this down-quantization, our int32 values will be
3432 // multiplied by a multiplier that has a fixed-point component and an
3433 // exponent component.
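        // In scalar terms (gemmlowp-style), each value acc is mapped to
        //   acc <<= max(exponent, 0)                          (sshl below)
        //   acc = SaturatingRoundingDoublingHighMul(acc, m)   (sqrdmulh)
        //   acc = RoundingRightShift(acc, -min(exponent, 0))  (srshl)
        // where m is multiplier_fixedpoint and
        // SaturatingRoundingDoublingHighMul is the gemmlowp name for
        // (acc * m + 2^30) >> 31 with saturation.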
3434
3435         // Load the exponent part of the multiplier.
3436 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
3437 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
3438 "add x5, x1, %x[row], lsl #2\n"
3439 "csel x1, x1, x5, eq\n"
3440
3441 "ldr q9, [x1]\n"
3442 "ldr q10, [x1, #16]\n"
3443
3444 "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
3445 "beq 403f\n"
3446 "smax v11.4s, v9.4s, v8.4s\n"
3447 "smax v12.4s, v10.4s, v8.4s\n"
3448 "sshl v16.4s, v16.4s, v11.4s\n"
3449 "sshl v17.4s, v17.4s, v12.4s\n"
3450 "sshl v18.4s, v18.4s, v11.4s\n"
3451 "sshl v19.4s, v19.4s, v12.4s\n"
3452 "sshl v20.4s, v20.4s, v11.4s\n"
3453 "sshl v21.4s, v21.4s, v12.4s\n"
3454 "sshl v22.4s, v22.4s, v11.4s\n"
3455 "sshl v23.4s, v23.4s, v12.4s\n"
3456 "sshl v24.4s, v24.4s, v11.4s\n"
3457 "sshl v25.4s, v25.4s, v12.4s\n"
3458 "sshl v26.4s, v26.4s, v11.4s\n"
3459 "sshl v27.4s, v27.4s, v12.4s\n"
3460 "sshl v28.4s, v28.4s, v11.4s\n"
3461 "sshl v29.4s, v29.4s, v12.4s\n"
3462 "sshl v30.4s, v30.4s, v11.4s\n"
3463 "sshl v31.4s, v31.4s, v12.4s\n"
3464 "403:\n"
3465
3466 "ldr q14, [x4]\n" // multiplier_fixedpoint
3467 "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
3468
3469 "smin v11.4s, v9.4s, v8.4s\n"
3470 "smin v12.4s, v10.4s, v8.4s\n"
3471
3472 // Apply the fixed-point part of the multiplier.
3473 "sqrdmulh v16.4s, v16.4s, v14.4s\n"
3474 "sqrdmulh v17.4s, v17.4s, v15.4s\n"
3475 "sqrdmulh v18.4s, v18.4s, v14.4s\n"
3476 "sqrdmulh v19.4s, v19.4s, v15.4s\n"
3477 "sqrdmulh v20.4s, v20.4s, v14.4s\n"
3478 "sqrdmulh v21.4s, v21.4s, v15.4s\n"
3479 "sqrdmulh v22.4s, v22.4s, v14.4s\n"
3480 "sqrdmulh v23.4s, v23.4s, v15.4s\n"
3481 "sqrdmulh v24.4s, v24.4s, v14.4s\n"
3482 "sqrdmulh v25.4s, v25.4s, v15.4s\n"
3483 "sqrdmulh v26.4s, v26.4s, v14.4s\n"
3484 "sqrdmulh v27.4s, v27.4s, v15.4s\n"
3485 "sqrdmulh v28.4s, v28.4s, v14.4s\n"
3486 "sqrdmulh v29.4s, v29.4s, v15.4s\n"
3487 "sqrdmulh v30.4s, v30.4s, v14.4s\n"
3488 "sqrdmulh v31.4s, v31.4s, v15.4s\n"
3489
3490 // We have some rounding division-by-power-of-two to do. This should
3491 // always use "round to nearest". We allow for some
3492 // freedom in how ties are broken, to strike a good compromise of
3493 // performance on given hardware vs. perfect agreement of results
3494 // across hardware.
3495 //
3496 // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
3497 // defined tie-breaks to help performance. On NEON, this means that we
3498 // can just use the NEON rounding instructions, such as srshl. They
3499 // happen to be breaking ties upward.
3500 //
3501         // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
3502         // break-ties-away-from-zero, as described in Appendix B of
3503         // https://arxiv.org/pdf/1712.05877.pdf
3504         // When we wrote that, we thought that it would be less biased than
3505         // the NEON upward tie-breaks, and we had observed some improvement
3506         // on one model. However, it is only less biased for data centered
3507         // at zero, which was likely the case in that model, but is not
3508         // always the case. If we wanted something more consistently
3509         // unbiased, we should try breaking ties toward nearest-even.
3510 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
3511 // Fix up values to be right-shifted, so that the (round to nearest,
3512 // break ties upward) behavior of srshl applied to these fixed-up
3513 // values, produces the same result as the desired (round to nearest,
3514 // break ties away from zero) behavior on the original values.
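        // Mechanism: v11/v12 hold the (non-positive) shift amounts, whose
        // sign bit is set exactly when a right shift will occur. Thus
        // (acc AND shift) >> 31 yields -1 precisely for negative values
        // about to be right-shifted, and sqadd'ing that -1 turns srshl's
        // break-ties-upward into break-ties-away-from-zero for them.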
3515 "and v8.16b, v16.16b, v11.16b\n"
3516 "and v9.16b, v17.16b, v12.16b\n"
3517 "and v14.16b, v18.16b, v11.16b\n"
3518 "and v15.16b, v19.16b, v12.16b\n"
3519 "sshr v8.4s, v8.4s, #31\n"
3520 "sshr v9.4s, v9.4s, #31\n"
3521 "sshr v14.4s, v14.4s, #31\n"
3522 "sshr v15.4s, v15.4s, #31\n"
3523 "sqadd v16.4s, v16.4s, v8.4s\n"
3524 "sqadd v17.4s, v17.4s, v9.4s\n"
3525 "sqadd v18.4s, v18.4s, v14.4s\n"
3526 "sqadd v19.4s, v19.4s, v15.4s\n"
3527 "and v8.16b, v20.16b, v11.16b\n"
3528 "and v9.16b, v21.16b, v12.16b\n"
3529 "and v14.16b, v22.16b, v11.16b\n"
3530 "and v15.16b, v23.16b, v12.16b\n"
3531 "sshr v8.4s, v8.4s, #31\n"
3532 "sshr v9.4s, v9.4s, #31\n"
3533 "sshr v14.4s, v14.4s, #31\n"
3534 "sshr v15.4s, v15.4s, #31\n"
3535 "sqadd v20.4s, v20.4s, v8.4s\n"
3536 "sqadd v21.4s, v21.4s, v9.4s\n"
3537 "sqadd v22.4s, v22.4s, v14.4s\n"
3538 "sqadd v23.4s, v23.4s, v15.4s\n"
3539 "and v8.16b, v24.16b, v11.16b\n"
3540 "and v9.16b, v25.16b, v12.16b\n"
3541 "and v14.16b, v26.16b, v11.16b\n"
3542 "and v15.16b, v27.16b, v12.16b\n"
3543 "sshr v8.4s, v8.4s, #31\n"
3544 "sshr v9.4s, v9.4s, #31\n"
3545 "sshr v14.4s, v14.4s, #31\n"
3546 "sshr v15.4s, v15.4s, #31\n"
3547 "sqadd v24.4s, v24.4s, v8.4s\n"
3548 "sqadd v25.4s, v25.4s, v9.4s\n"
3549 "sqadd v26.4s, v26.4s, v14.4s\n"
3550 "sqadd v27.4s, v27.4s, v15.4s\n"
3551 "and v8.16b, v28.16b, v11.16b\n"
3552 "and v9.16b, v29.16b, v12.16b\n"
3553 "and v14.16b, v30.16b, v11.16b\n"
3554 "and v15.16b, v31.16b, v12.16b\n"
3555 "sshr v8.4s, v8.4s, #31\n"
3556 "sshr v9.4s, v9.4s, #31\n"
3557 "sshr v14.4s, v14.4s, #31\n"
3558 "sshr v15.4s, v15.4s, #31\n"
3559 "sqadd v28.4s, v28.4s, v8.4s\n"
3560 "sqadd v29.4s, v29.4s, v9.4s\n"
3561 "sqadd v30.4s, v30.4s, v14.4s\n"
3562 "sqadd v31.4s, v31.4s, v15.4s\n"
3563 #endif
3564 // At this point we have reduced the problem of correctly implementing
3565 // rounding divide-by-power-of-two, to what the SRSHL instruction can
3566 // do.
3567 "srshl v16.4s, v16.4s, v11.4s\n"
3568 "srshl v17.4s, v17.4s, v12.4s\n"
3569 "srshl v18.4s, v18.4s, v11.4s\n"
3570 "srshl v19.4s, v19.4s, v12.4s\n"
3571 "srshl v20.4s, v20.4s, v11.4s\n"
3572 "srshl v21.4s, v21.4s, v12.4s\n"
3573 "srshl v22.4s, v22.4s, v11.4s\n"
3574 "srshl v23.4s, v23.4s, v12.4s\n"
3575 "srshl v24.4s, v24.4s, v11.4s\n"
3576 "srshl v25.4s, v25.4s, v12.4s\n"
3577 "srshl v26.4s, v26.4s, v11.4s\n"
3578 "srshl v27.4s, v27.4s, v12.4s\n"
3579 "srshl v28.4s, v28.4s, v11.4s\n"
3580 "srshl v29.4s, v29.4s, v12.4s\n"
3581 "srshl v30.4s, v30.4s, v11.4s\n"
3582 "srshl v31.4s, v31.4s, v12.4s\n"
3583
3584 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
3585 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
3586 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
3587 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
3588
3589 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
3590
3591 // Cast-and-saturate from int32 to int16
3592 "sqxtn v16.4h, v16.4s\n"
3593 "sqxtn2 v16.8h, v17.4s\n"
3594 "sqxtn v17.4h, v18.4s\n"
3595 "sqxtn2 v17.8h, v19.4s\n"
3596 "sqxtn v18.4h, v20.4s\n"
3597 "sqxtn2 v18.8h, v21.4s\n"
3598 "sqxtn v19.4h, v22.4s\n"
3599 "sqxtn2 v19.8h, v23.4s\n"
3600 "sqxtn v20.4h, v24.4s\n"
3601 "sqxtn2 v20.8h, v25.4s\n"
3602 "sqxtn v21.4h, v26.4s\n"
3603 "sqxtn2 v21.8h, v27.4s\n"
3604 "sqxtn v22.4h, v28.4s\n"
3605 "sqxtn2 v22.8h, v29.4s\n"
3606 "sqxtn v23.4h, v30.4s\n"
3607 "sqxtn2 v23.8h, v31.4s\n"
3608
3609 // At this point, v24 -- v31 aren't used anymore for the current block,
3610 // so we can start clearing these accumulators for the next block
3611 // (next iteration of the main loop).
3612 RUY_MAKE_ZERO(v24)
3613 RUY_MAKE_ZERO(v25)
3614 RUY_MAKE_ZERO(v26)
3615 RUY_MAKE_ZERO(v27)
3616 RUY_MAKE_ZERO(v28)
3617 RUY_MAKE_ZERO(v29)
3618 RUY_MAKE_ZERO(v30)
3619 RUY_MAKE_ZERO(v31)
3620
3621 // Add the destination zero point
3622 "dup v14.8h, v13.h[4]\n"
3623 "add v16.8h, v16.8h, v14.8h\n"
3624 "add v17.8h, v17.8h, v14.8h\n"
3625 "add v18.8h, v18.8h, v14.8h\n"
3626 "add v19.8h, v19.8h, v14.8h\n"
3627 "add v20.8h, v20.8h, v14.8h\n"
3628 "add v21.8h, v21.8h, v14.8h\n"
3629 "add v22.8h, v22.8h, v14.8h\n"
3630 "add v23.8h, v23.8h, v14.8h\n"
3631
3632 // Cast-and-saturate from int16 to uint8
3633 "sqxtun v16.8b, v16.8h\n"
3634 "sqxtun2 v16.16b, v17.8h\n"
3635 "sqxtun v17.8b, v18.8h\n"
3636 "sqxtun2 v17.16b, v19.8h\n"
3637 "sqxtun v18.8b, v20.8h\n"
3638 "sqxtun2 v18.16b, v21.8h\n"
3639 "sqxtun v19.8b, v22.8h\n"
3640 "sqxtun2 v19.16b, v23.8h\n"
3641
3642 // Load the clamp_min, clamp_max bounds
3643 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
3644 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
3645 "dup v14.16b, w2\n" // clamp_min
3646 "dup v15.16b, w3\n" // clamp_max
3647
3648 // Apply the clamp_min bound
3649 "umax v16.16b, v16.16b, v14.16b\n"
3650 "umax v17.16b, v17.16b, v14.16b\n"
3651 "umax v18.16b, v18.16b, v14.16b\n"
3652 "umax v19.16b, v19.16b, v14.16b\n"
3653
3654 // Apply the clamp_max bound
3655 "umin v16.16b, v16.16b, v15.16b\n"
3656 "umin v17.16b, v17.16b, v15.16b\n"
3657 "umin v18.16b, v18.16b, v15.16b\n"
3658 "umin v19.16b, v19.16b, v15.16b\n"
3659
3660 // Make it so that all of the final 8bit values are stored in the
3661 // first 64bits of 128bit NEON registers, so they can be stored
3662 // by 64bit st1 store instructions with byte alignment.
3663 "dup d20, v16.d[1]\n"
3664 "dup d21, v17.d[1]\n"
3665 "dup d22, v18.d[1]\n"
3666 "dup d23, v19.d[1]\n"
3667
3668         // Compute how much of the 8x8 block of destination 8bit values that
3669         // we have computed fits in the destination matrix. Typically, all of
3670 // it fits, but when the destination matrix shape is not a multiple
3671 // of 8x8, there are some 8x8 blocks along the boundaries that do
3672 // not fit entirely.
3673 "sub w1, %w[dst_rows], %w[row]\n"
3674 "sub w2, %w[dst_cols], %w[col]\n"
3675 "mov w3, #8\n"
3676 "cmp w1, #8\n"
3677 // Compute w1 = how many rows of the 8x8 block fit
3678 "csel w1, w1, w3, le\n"
3679 "cmp w2, #8\n"
3680 // Compute w2 = how many cols of the 8x8 block fit
3681 "csel w2, w2, w3, le\n"
3682
3683 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
3684 "cmp w1, w3\n"
3685 "ccmp w2, w3, 0, eq\n"
3686 // Yes, all of the 8x8 block fits, go to fast path.
3687 "beq 30f\n"
3688 // Not all of the 8x8 block fits.
3689 // Set (x3 address, x4 stride) to write to dst_tmp_buf
3690 "mov x3, %[dst_tmp_buf]\n"
3691 "mov x4, #8\n"
3692 "b 31f\n"
3693 "30:\n"
3694 // Yes, all of the 8x8 block fits.
3695 // Set (x3 address, x4 stride) to write directly to destination matrix.
3696 "mov x3, %[dst_ptr]\n"
3697 "mov x4, x11\n"
3698 "31:\n"
3699
3700 // Write our 8bit values to the destination described by
3701 // (x3 address, x4 stride).
3702 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3703 "st1 {v16.8b}, [x3], x4\n"
3704 RUY_MAKE_ZERO(v16)
3705 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3706 "st1 {v20.8b}, [x3], x4\n"
3707 RUY_MAKE_ZERO(v20)
3708 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3709 "st1 {v17.8b}, [x3], x4\n"
3710 RUY_MAKE_ZERO(v17)
3711 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3712 "st1 {v21.8b}, [x3], x4\n"
3713 RUY_MAKE_ZERO(v21)
3714 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3715 "st1 {v18.8b}, [x3], x4\n"
3716 RUY_MAKE_ZERO(v18)
3717 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3718 "st1 {v22.8b}, [x3], x4\n"
3719 RUY_MAKE_ZERO(v22)
3720 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3721 "st1 {v19.8b}, [x3], x4\n"
3722 RUY_MAKE_ZERO(v19)
3723 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3724 "st1 {v23.8b}, [x3], x4\n"
3725 RUY_MAKE_ZERO(v23)
3726
3727 // For the next block: perform the first few multiply-adds on the data
3728 // that we have already loaded.
3729 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3730 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3731 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3732 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3733
3734 // If all of the 8x8 block fits, we just finished writing it to the
3735 // destination, so we skip the next part.
3736 "beq 41f\n"
3737 // Not all of the 8x8 block fits in the destination matrix. We just
3738 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
3739 // it to copy into the destination matrix the part that fits.
3740 "mov x3, %[dst_tmp_buf]\n"
3741 "mov x4, %[dst_ptr]\n"
3742 "mov w6, #0\n"
3743 "50:\n"
3744 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
3745 "mov w5, #0\n"
3746 "51:\n"
3747 "ldrb w7, [x3, w5, uxtw]\n"
3748 "strb w7, [x4, w5, uxtw]\n"
3749 "add w5, w5, #1\n"
3750 "cmp w5, w1\n"
3751 "blt 51b\n"
3752 "add w6, w6, #1\n"
3753 "add x3, x3, #8\n"
3754 "add x4, x4, x11\n"
3755 "cmp w6, w2\n"
3756 "blt 50b\n"
3757 "41:\n"
3758 "add %[dst_ptr], %[dst_ptr], #8\n"
3759 // At this point we have completely finished writing values to the
3760 // destination matrix for the current block.
3761
3762 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
3763
3764 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
3765
3766 // Cast-and-saturate from int32 to int16
3767 "sqxtn v16.4h, v16.4s\n"
3768 "sqxtn2 v16.8h, v17.4s\n"
3769 "sqxtn v17.4h, v18.4s\n"
3770 "sqxtn2 v17.8h, v19.4s\n"
3771 "sqxtn v18.4h, v20.4s\n"
3772 "sqxtn2 v18.8h, v21.4s\n"
3773 "sqxtn v19.4h, v22.4s\n"
3774 "sqxtn2 v19.8h, v23.4s\n"
3775 "sqxtn v20.4h, v24.4s\n"
3776 "sqxtn2 v20.8h, v25.4s\n"
3777 "sqxtn v21.4h, v26.4s\n"
3778 "sqxtn2 v21.8h, v27.4s\n"
3779 "sqxtn v22.4h, v28.4s\n"
3780 "sqxtn2 v22.8h, v29.4s\n"
3781 "sqxtn v23.4h, v30.4s\n"
3782 "sqxtn2 v23.8h, v31.4s\n"
3783
3784 // At this point, v24 -- v31 aren't used anymore for the current block,
3785 // so we can start clearing these accumulators for the next block
3786 // (next iteration of the main loop).
3787 RUY_MAKE_ZERO(v24)
3788 RUY_MAKE_ZERO(v25)
3789 RUY_MAKE_ZERO(v26)
3790 RUY_MAKE_ZERO(v27)
3791 RUY_MAKE_ZERO(v28)
3792 RUY_MAKE_ZERO(v29)
3793 RUY_MAKE_ZERO(v30)
3794 RUY_MAKE_ZERO(v31)
3795
3796 // Add the destination zero point
3797 "dup v14.8h, v13.h[4]\n"
3798 "add v16.8h, v16.8h, v14.8h\n"
3799 "add v17.8h, v17.8h, v14.8h\n"
3800 "add v18.8h, v18.8h, v14.8h\n"
3801 "add v19.8h, v19.8h, v14.8h\n"
3802 "add v20.8h, v20.8h, v14.8h\n"
3803 "add v21.8h, v21.8h, v14.8h\n"
3804 "add v22.8h, v22.8h, v14.8h\n"
3805 "add v23.8h, v23.8h, v14.8h\n"
3806
3807         // Cast-and-saturate from int16 to int8
3808 "sqxtn v16.8b, v16.8h\n"
3809 "sqxtn2 v16.16b, v17.8h\n"
3810 "sqxtn v17.8b, v18.8h\n"
3811 "sqxtn2 v17.16b, v19.8h\n"
3812 "sqxtn v18.8b, v20.8h\n"
3813 "sqxtn2 v18.16b, v21.8h\n"
3814 "sqxtn v19.8b, v22.8h\n"
3815 "sqxtn2 v19.16b, v23.8h\n"
3816
3817 // Load the clamp_min, clamp_max bounds
3818 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
3819 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
3820 "dup v14.16b, w2\n" // clamp_min
3821 "dup v15.16b, w3\n" // clamp_max
3822
3823 // Apply the clamp_min bound
3824 "smax v16.16b, v16.16b, v14.16b\n"
3825 "smax v17.16b, v17.16b, v14.16b\n"
3826 "smax v18.16b, v18.16b, v14.16b\n"
3827 "smax v19.16b, v19.16b, v14.16b\n"
3828
3829 // Apply the clamp_max bound
3830 "smin v16.16b, v16.16b, v15.16b\n"
3831 "smin v17.16b, v17.16b, v15.16b\n"
3832 "smin v18.16b, v18.16b, v15.16b\n"
3833 "smin v19.16b, v19.16b, v15.16b\n"
3834
3835 // Make it so that all of the final 8bit values are stored in the
3836 // first 64bits of 128bit NEON registers, so they can be stored
3837 // by 64bit st1 store instructions with byte alignment.
3838 "dup d20, v16.d[1]\n"
3839 "dup d21, v17.d[1]\n"
3840 "dup d22, v18.d[1]\n"
3841 "dup d23, v19.d[1]\n"
3842
3843         // Compute how much of the 8x8 block of destination 8bit values that
3844         // we have computed fits in the destination matrix. Typically, all of
3845 // it fits, but when the destination matrix shape is not a multiple
3846 // of 8x8, there are some 8x8 blocks along the boundaries that do
3847 // not fit entirely.
3848 "sub w1, %w[dst_rows], %w[row]\n"
3849 "sub w2, %w[dst_cols], %w[col]\n"
3850 "mov w3, #8\n"
3851 "cmp w1, #8\n"
3852 // Compute w1 = how many rows of the 8x8 block fit
3853 "csel w1, w1, w3, le\n"
3854 "cmp w2, #8\n"
3855 // Compute w2 = how many cols of the 8x8 block fit
3856 "csel w2, w2, w3, le\n"
3857
3858 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
3859 "cmp w1, w3\n"
3860 "ccmp w2, w3, 0, eq\n"
3861 // Yes, all of the 8x8 block fits, go to fast path.
3862 "beq 130f\n"
3863 // Not all of the 8x8 block fits.
3864 // Set (x3 address, x4 stride) to write to dst_tmp_buf
3865 "mov x3, %[dst_tmp_buf]\n"
3866 "mov x4, #8\n"
3867 "b 131f\n"
3868 "130:\n"
3869 // Yes, all of the 8x8 block fits.
3870 // Set (x3 address, x4 stride) to write directly to destination matrix.
3871 "mov x3, %[dst_ptr]\n"
3872 "mov x4, x11\n"
3873 "131:\n"
3874
3875 // Write our 8bit values to the destination described by
3876 // (x3 address, x4 stride).
3877 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3878 "st1 {v16.8b}, [x3], x4\n"
3879 RUY_MAKE_ZERO(v16)
3880 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3881 "st1 {v20.8b}, [x3], x4\n"
3882 RUY_MAKE_ZERO(v20)
3883 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3884 "st1 {v17.8b}, [x3], x4\n"
3885 RUY_MAKE_ZERO(v17)
3886 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3887 "st1 {v21.8b}, [x3], x4\n"
3888 RUY_MAKE_ZERO(v21)
3889 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3890 "st1 {v18.8b}, [x3], x4\n"
3891 RUY_MAKE_ZERO(v18)
3892 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3893 "st1 {v22.8b}, [x3], x4\n"
3894 RUY_MAKE_ZERO(v22)
3895 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3896 "st1 {v19.8b}, [x3], x4\n"
3897 RUY_MAKE_ZERO(v19)
3898 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3899 "st1 {v23.8b}, [x3], x4\n"
3900 RUY_MAKE_ZERO(v23)
3901
3902 // For the next block: perform the first few multiply-adds on the data
3903 // that we have already loaded.
3904 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3905 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3906 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3907 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3908
3909 // If all of the 8x8 block fits, we just finished writing it to the
3910 // destination, so we skip the next part.
3911 "beq 141f\n"
3912 // Not all of the 8x8 block fits in the destination matrix. We just
3913 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
3914 // it to copy into the destination matrix the part that fits.
3915 "mov x3, %[dst_tmp_buf]\n"
3916 "mov x4, %[dst_ptr]\n"
3917 "mov w6, #0\n"
3918 "150:\n"
3919 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
3920 "mov w5, #0\n"
3921 "151:\n"
3922 "ldrb w7, [x3, w5, uxtw]\n"
3923 "strb w7, [x4, w5, uxtw]\n"
3924 "add w5, w5, #1\n"
3925 "cmp w5, w1\n"
3926 "blt 151b\n"
3927 "add w6, w6, #1\n"
3928 "add x3, x3, #8\n"
3929 "add x4, x4, x11\n"
3930 "cmp w6, w2\n"
3931 "blt 150b\n"
3932 "141:\n"
3933 "add %[dst_ptr], %[dst_ptr], #8\n"
3934 // At this point we have completely finished writing values to the
3935 // destination matrix for the current block.
3936
3937 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
3938
3939 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
3940
3941 // Add the destination zero point
3942 "dup v14.8h, v13.h[4]\n"
3943 "saddw v16.4s, v16.4s, v14.4h\n"
3944 "saddw v17.4s, v17.4s, v14.4h\n"
3945 "saddw v18.4s, v18.4s, v14.4h\n"
3946 "saddw v19.4s, v19.4s, v14.4h\n"
3947 "saddw v20.4s, v20.4s, v14.4h\n"
3948 "saddw v21.4s, v21.4s, v14.4h\n"
3949 "saddw v22.4s, v22.4s, v14.4h\n"
3950 "saddw v23.4s, v23.4s, v14.4h\n"
3951 "saddw v24.4s, v24.4s, v14.4h\n"
3952 "saddw v25.4s, v25.4s, v14.4h\n"
3953 "saddw v26.4s, v26.4s, v14.4h\n"
3954 "saddw v27.4s, v27.4s, v14.4h\n"
3955 "saddw v28.4s, v28.4s, v14.4h\n"
3956 "saddw v29.4s, v29.4s, v14.4h\n"
3957 "saddw v30.4s, v30.4s, v14.4h\n"
3958 "saddw v31.4s, v31.4s, v14.4h\n"
3959
3960 // Cast-and-saturate from int32 to int16
3961 "sqxtn v16.4h, v16.4s\n"
3962 "sqxtn2 v16.8h, v17.4s\n"
3963 "sqxtn v17.4h, v18.4s\n"
3964 "sqxtn2 v17.8h, v19.4s\n"
3965 "sqxtn v18.4h, v20.4s\n"
3966 "sqxtn2 v18.8h, v21.4s\n"
3967 "sqxtn v19.4h, v22.4s\n"
3968 "sqxtn2 v19.8h, v23.4s\n"
3969 "sqxtn v20.4h, v24.4s\n"
3970 "sqxtn2 v20.8h, v25.4s\n"
3971 "sqxtn v21.4h, v26.4s\n"
3972 "sqxtn2 v21.8h, v27.4s\n"
3973 "sqxtn v22.4h, v28.4s\n"
3974 "sqxtn2 v22.8h, v29.4s\n"
3975 "sqxtn v23.4h, v30.4s\n"
3976 "sqxtn2 v23.8h, v31.4s\n"
3977
3978 // At this point, v24 -- v31 aren't used anymore for the current block,
3979 // so we can start clearing these accumulators for the next block
3980 // (next iteration of the main loop).
3981 RUY_MAKE_ZERO(v24)
3982 RUY_MAKE_ZERO(v25)
3983 RUY_MAKE_ZERO(v26)
3984 RUY_MAKE_ZERO(v27)
3985 RUY_MAKE_ZERO(v28)
3986 RUY_MAKE_ZERO(v29)
3987 RUY_MAKE_ZERO(v30)
3988 RUY_MAKE_ZERO(v31)
3989
3990 // Load the clamp_min, clamp_max bounds
3991 "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
3992 "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
3993 "dup v14.8h, w2\n" // clamp_min
3994 "dup v15.8h, w3\n" // clamp_max
3995
3996 // Apply the clamp_min bound
3997 "smax v16.8h, v16.8h, v14.8h\n"
3998 "smax v17.8h, v17.8h, v14.8h\n"
3999 "smax v18.8h, v18.8h, v14.8h\n"
4000 "smax v19.8h, v19.8h, v14.8h\n"
4001 "smax v20.8h, v20.8h, v14.8h\n"
4002 "smax v21.8h, v21.8h, v14.8h\n"
4003 "smax v22.8h, v22.8h, v14.8h\n"
4004 "smax v23.8h, v23.8h, v14.8h\n"
4005 // Apply the clamp_max bound
4006 "smin v16.8h, v16.8h, v15.8h\n"
4007 "smin v17.8h, v17.8h, v15.8h\n"
4008 "smin v18.8h, v18.8h, v15.8h\n"
4009 "smin v19.8h, v19.8h, v15.8h\n"
4010 "smin v20.8h, v20.8h, v15.8h\n"
4011 "smin v21.8h, v21.8h, v15.8h\n"
4012 "smin v22.8h, v22.8h, v15.8h\n"
4013 "smin v23.8h, v23.8h, v15.8h\n"
4014
4015         // Compute how much of the 8x8 block of destination 16bit values that
4016         // we have computed fits in the destination matrix. Typically, all of
4017 // it fits, but when the destination matrix shape is not a multiple
4018 // of 8x8, there are some 8x8 blocks along the boundaries that do
4019 // not fit entirely.
4020 "sub w1, %w[dst_rows], %w[row]\n"
4021 "sub w2, %w[dst_cols], %w[col]\n"
4022 "mov w3, #8\n"
4023 "cmp w1, #8\n"
4024 // Compute w1 = how many rows of the 8x8 block fit
4025 "csel w1, w1, w3, le\n"
4026 "cmp w2, #8\n"
4027         // Compute w2 = how many cols of the 8x8 block fit
4028 "csel w2, w2, w3, le\n"
4029
4030 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
4031 "cmp w1, w3\n"
4032 "ccmp w2, w3, 0, eq\n"
4033 // Yes, all of the 8x8 block fits, go to fast path.
4034 "beq 230f\n"
4035 // Not all of the 8x8 block fits.
4036 // Set (x3 address, x4 stride) to write to dst_tmp_buf
4037 "mov x3, %[dst_tmp_buf]\n"
4038 "mov x4, #16\n"
4039 "b 231f\n"
4040 "230:\n"
4041 // Yes, all of the 8x8 block fits.
4042 // Set (x3 address, x4 stride) to write directly to destination matrix.
4043 "mov x3, %[dst_ptr]\n"
4044 "mov x4, x11\n"
4045 "231:\n"
4046
4047 // Write our 16bit values to the destination described by
4048 // (x3 address, x4 stride).
4049 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4050 "st1 {v16.8h}, [x3], x4\n"
4051 RUY_MAKE_ZERO(v16)
4052 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4053 "st1 {v17.8h}, [x3], x4\n"
4054 RUY_MAKE_ZERO(v17)
4055 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4056 "st1 {v18.8h}, [x3], x4\n"
4057 RUY_MAKE_ZERO(v18)
4058 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4059 "st1 {v19.8h}, [x3], x4\n"
4060 RUY_MAKE_ZERO(v19)
4061 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4062 "st1 {v20.8h}, [x3], x4\n"
4063 RUY_MAKE_ZERO(v20)
4064 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4065 "st1 {v21.8h}, [x3], x4\n"
4066 RUY_MAKE_ZERO(v21)
4067 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4068 "st1 {v22.8h}, [x3], x4\n"
4069 RUY_MAKE_ZERO(v22)
4070 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4071 "st1 {v23.8h}, [x3], x4\n"
4072 RUY_MAKE_ZERO(v23)
4073
4074 // For the next block: perform the first few multiply-adds on the data
4075 // that we have already loaded.
4076 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4077 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
4078 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
4079 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
4080
4081 // If all of the 8x8 block fits, we just finished writing it to the
4082 // destination, so we skip the next part.
4083 "beq 241f\n"
4084 // Not all of the 8x8 block fits in the destination matrix. We just
4085 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4086 // it to copy into the destination matrix the part that fits.
4087 "mov x3, %[dst_tmp_buf]\n"
4088 "mov x4, %[dst_ptr]\n"
4089 "mov w6, #0\n"
4090 "250:\n"
4091 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4092 "mov w5, #0\n"
4093 "251:\n"
4094 "ldrsh w7, [x3, x5, lsl #1]\n"
4095 "strh w7, [x4, x5, lsl #1]\n"
4096 "add w5, w5, #1\n"
4097 "cmp w5, w1\n"
4098 "blt 251b\n"
4099 "add w6, w6, #1\n"
4100 "add x3, x3, #16\n"
4101 "add x4, x4, x11\n"
4102 "cmp w6, w2\n"
4103 "blt 250b\n"
4104 "241:\n"
4105 "add %[dst_ptr], %[dst_ptr], #16\n"
4106 // At this point we have completely finished writing values to the
4107 // destination matrix for the current block.
4108
4109 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
4110
4111 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
4112
4113 // Since the store type is the same as the accum type, no need for
4114 // downcast. There's also no need for clamp by min/max.
4115
4116         // Compute how much of the 8x8 block of destination 32bit values that
4117         // we have computed fits in the destination matrix. Typically, all of
4118 // it fits, but when the destination matrix shape is not a multiple
4119 // of 8x8, there are some 8x8 blocks along the boundaries that do
4120 // not fit entirely.
4121 "sub w1, %w[dst_rows], %w[row]\n"
4122 "sub w2, %w[dst_cols], %w[col]\n"
4123 "mov w3, #8\n"
4124 "cmp w1, #8\n"
4125 // Compute w1 = how many rows of the 8x8 block fit
4126 "csel w1, w1, w3, le\n"
4127 "cmp w2, #8\n"
4128         // Compute w2 = how many cols of the 8x8 block fit
4129 "csel w2, w2, w3, le\n"
4130
4131 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
4132 "cmp w1, w3\n"
4133 "ccmp w2, w3, 0, eq\n"
4134 // Yes, all of the 8x8 block fits, go to fast path.
4135 "beq 330f\n"
4136 // Not all of the 8x8 block fits.
4137 // Write to dst_tmp_buf
4138 "mov x3, %[dst_tmp_buf]\n"
4139 "st1 {v16.4s}, [x3], #16\n"
4140 RUY_MAKE_ZERO(v16)
4141 "st1 {v17.4s}, [x3], #16\n"
4142 RUY_MAKE_ZERO(v17)
4143 "st1 {v18.4s}, [x3], #16\n"
4144 RUY_MAKE_ZERO(v18)
4145 "st1 {v19.4s}, [x3], #16\n"
4146 RUY_MAKE_ZERO(v19)
4147 "st1 {v20.4s}, [x3], #16\n"
4148 RUY_MAKE_ZERO(v20)
4149 "st1 {v21.4s}, [x3], #16\n"
4150 RUY_MAKE_ZERO(v21)
4151 "st1 {v22.4s}, [x3], #16\n"
4152 RUY_MAKE_ZERO(v22)
4153 "st1 {v23.4s}, [x3], #16\n"
4154 RUY_MAKE_ZERO(v23)
4155 "st1 {v24.4s}, [x3], #16\n"
4156 RUY_MAKE_ZERO(v24)
4157 "st1 {v25.4s}, [x3], #16\n"
4158 RUY_MAKE_ZERO(v25)
4159 "st1 {v26.4s}, [x3], #16\n"
4160 RUY_MAKE_ZERO(v26)
4161 "st1 {v27.4s}, [x3], #16\n"
4162 RUY_MAKE_ZERO(v27)
4163 "st1 {v28.4s}, [x3], #16\n"
4164 RUY_MAKE_ZERO(v28)
4165 "st1 {v29.4s}, [x3], #16\n"
4166 RUY_MAKE_ZERO(v29)
4167 "st1 {v30.4s}, [x3], #16\n"
4168 RUY_MAKE_ZERO(v30)
4169 "st1 {v31.4s}, [x3], #16\n"
4170 RUY_MAKE_ZERO(v31)
4171
4172 "b 331f\n"
4173
4174 "330:\n"
4175 // Yes, all of the 8x8 block fits.
4176 "mov x4, %[dst_ptr]\n"
4177 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4178 "mov x3, x4\n"
4179 "st1 {v16.4s, v17.4s}, [x3], #32\n"
4180 RUY_MAKE_ZERO(v16)
4181 RUY_MAKE_ZERO(v17)
4182 "add x4, x4, x11\n"
4183 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4184 "mov x3, x4\n"
4185 "st1 {v18.4s, v19.4s}, [x3], #32\n"
4186 RUY_MAKE_ZERO(v18)
4187 RUY_MAKE_ZERO(v19)
4188 "add x4, x4, x11\n"
4189 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4190 "mov x3, x4\n"
4191 "st1 {v20.4s, v21.4s}, [x3], #32\n"
4192 RUY_MAKE_ZERO(v20)
4193 RUY_MAKE_ZERO(v21)
4194 "add x4, x4, x11\n"
4195 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4196 "mov x3, x4\n"
4197 "st1 {v22.4s, v23.4s}, [x3], #32\n"
4198 RUY_MAKE_ZERO(v22)
4199 RUY_MAKE_ZERO(v23)
4200 "add x4, x4, x11\n"
4201 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4202 "mov x3, x4\n"
4203 "st1 {v24.4s, v25.4s}, [x3], #32\n"
4204 RUY_MAKE_ZERO(v24)
4205 RUY_MAKE_ZERO(v25)
4206 "add x4, x4, x11\n"
4207 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4208 "mov x3, x4\n"
4209 "st1 {v26.4s, v27.4s}, [x3], #32\n"
4210 RUY_MAKE_ZERO(v26)
4211 RUY_MAKE_ZERO(v27)
4212 "add x4, x4, x11\n"
4213 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4214 "mov x3, x4\n"
4215 "st1 {v28.4s, v29.4s}, [x3], #32\n"
4216 RUY_MAKE_ZERO(v28)
4217 RUY_MAKE_ZERO(v29)
4218 "add x4, x4, x11\n"
4219 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4220 "mov x3, x4\n"
4221 "st1 {v30.4s, v31.4s}, [x3], #32\n"
4222 RUY_MAKE_ZERO(v30)
4223 RUY_MAKE_ZERO(v31)
4224
4225 "331:\n"
4226
4227 // For the next block: perform the first few multiply-adds on the data
4228 // that we have already loaded.
4229 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4230 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
4231 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
4232 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
4233
4234 // If all of the 8x8 block fits, we just finished writing it to the
4235 // destination, so we skip the next part.
4236 "beq 341f\n"
4237
4238 // Not all of the 8x8 block fits in the destination matrix. We just
4239 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4240 // it to copy into the destination matrix the part that fits.
4241 "mov x3, %[dst_tmp_buf]\n"
4242 "mov x4, %[dst_ptr]\n"
4243 "mov w6, #0\n"
4244 "350:\n"
4245 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4246 "mov w5, #0\n"
4247 "351:\n"
4248 "ldr w7, [x3, x5, lsl #2]\n"
4249 "str w7, [x4, x5, lsl #2]\n"
4250 "add w5, w5, #1\n"
4251 "cmp w5, w1\n"
4252 "blt 351b\n"
4253 "add w6, w6, #1\n"
4254 "add x3, x3, #32\n"
4255 "add x4, x4, x11\n"
4256 "cmp w6, w2\n"
4257 "blt 350b\n"
4258 "341:\n"
4259 "add %[dst_ptr], %[dst_ptr], #32\n"
4260 // At this point we have completely finished writing values to the
4261 // destination matrix for the current block.
4262
4263 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
4264
4265 // Reload some params --- we had used x5 -- x7 for a few other things
4266 // since the last time we had loaded them.
4267 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
4268 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
4269 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
4270
4271 // Move to the next block of the destination matrix, for the next iter
4272 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
4273 // been updated earlier.
4274 // Have we reached the end row?
4275 "cmp %w[row], w7\n"
4276 "beq 20f\n" // yes, end row.
4277 // Not end row. Move to the next row.
4278 "add %w[row], %w[row], #8\n"
4279 "b 21f\n"
4280 "20:\n"
4281 // Was already at end row.
4282 "mov %w[row], w6\n" // Move back to first row.
4283 "add %w[col], %w[col], #8\n" // Move to the next column.
4284 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
4285 "mov %[dst_ptr], %[dst_col_ptr]\n"
4286 "21:\n"
4287
4288 // Main loop exit condition: have we hit the end column?
4289 "cmp %w[col], w8\n"
4290
4291 // w1 is the number of levels of depth that we have already loaded
4292 // LHS and RHS data for. Corresponding to the initial ld1 instructions
4293 // above, this is currently 4.
4294 "mov w1, #4\n"
4295
4296 "ble 1b\n"
4297
4298 // clang-format on
4299
4300 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
4301 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
4302 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
4303         : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
4304 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
4305 [dst_type_id] "r"(params.dst_type_id)
4306 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
4307 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
4308 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
4309 "v26", "v27", "v28", "v29", "v30", "v31");
4310 }
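
// For reference, a scalar sketch of the per-element requantization that the
// asm kernels in this file perform for 8-bit destinations: the two-part
// multiplier, the destination zero point, and the final clamp. This helper
// is illustrative only (its name is not part of ruy); it elides the
// saturation edge case of sqrdmulh (both operands equal to INT32_MIN) and
// breaks ties upward, matching the RUY_OPT_NATIVE_ROUNDING path.
inline std::int8_t ReferenceRequantizeOneValueSketch(
    std::int32_t acc, std::int32_t multiplier_fixedpoint,
    int multiplier_exponent, std::int32_t dst_zero_point,
    std::int32_t clamp_min, std::int32_t clamp_max) {
  // Optional left shift: the sshl step, taken when exponent > 0.
  if (multiplier_exponent > 0) {
    acc = acc << multiplier_exponent;
  }
  // Fixed-point multiplication: the sqrdmulh step,
  // (acc * m + 2^30) >> 31 computed in 64-bit.
  const std::int64_t prod =
      static_cast<std::int64_t>(acc) * multiplier_fixedpoint;
  acc = static_cast<std::int32_t>((prod + (std::int64_t{1} << 30)) >> 31);
  // Rounding right shift: the srshl step, taken when exponent < 0.
  if (multiplier_exponent < 0) {
    const int right_shift = -multiplier_exponent;
    acc = (acc + (std::int32_t{1} << (right_shift - 1))) >> right_shift;
  }
  // Add the destination zero point, then clamp to [clamp_min, clamp_max].
  acc += dst_zero_point;
  acc = acc < clamp_min ? clamp_min : acc;
  acc = acc > clamp_max ? clamp_max : acc;
  return static_cast<std::int8_t>(acc);
}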
4311
4312 // Similar to the above 8-bit dotprod kernel, but specialized for the case of
4313 // RHS cols == 1.
4314 // Relevant target CPUs for this kernel include the ARM Cortex-A76,
4315 // which is 64-bit, out-of-order, and has dotprod support.
4316 void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) {
4317 profiler::ScopeLabel label(
4318 "Kernel (kNeonDotprod, optimized for out-of-order cores)");
4319
4320 CheckOffsetsInKernelParams8bit(params);
4321
4322 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
4323 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
4324 const std::int8_t* lhs_ptr = lhs_col_ptr;
4325 const std::int8_t* rhs_ptr = rhs_col_ptr;
4326 void* dst_col_ptr = params.dst_base_ptr;
4327 void* dst_ptr = dst_col_ptr;
4328 int row = params.start_row;
4329 int col = params.start_col;
4330
4331   // The asm kernel below has the following NEON register allocation:
4332   //
4333   // v16 -- v17 are int32 accumulators.
4334   // During accumulation, v0 -- v3 are used to load int8 data from LHS and
4335   // RHS. v0 and v1 are used to load a 8x4 block of LHS, and v2 is used
4336   // to load a 4x1 block of RHS, like this:
4337 //
4338 // int8 RHS 4x1 block
4339 // /-------\
4340 // |v2.b[0]|
4341 // | ... |
4342 // |v2.b[3]|
4343 // \-------/
4344 // int8 LHS 8x4 block
4345 // /---------------------\ /--------\
4346 // |v0.b[0] ... v0.b[3] | |v16.s[0]|
4347 // | ... ... | | ... |
4348 // |v0.b[12] ... v0.b[15]| |v16.s[3]|
4349 // |v1.b[0] ... v1.b[3] | |v17.s[0]|
4350 // | ... ... | | ... |
4351 // |v1.b[12] ... v1.b[15]| |v17.s[3]|
4352 // \---------------------/ \--------/
4353 // int32 accumulators 8x1 block
4354 //
4355   // Unlike the kernel above, this kernel does not have a
4356   // RUY_OPT_MAX_STREAMING partially-unrolled variant, so v4 -- v7 are
4357   // simply unused, and v8 -- v15 are used for loading parameters used for
4358   // the post-accumulation part of the kernel.
4362 asm volatile(
4363 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
4364
4365 // clang-format off
4366
4367 // Load some parameters into registers.
4368 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
4369 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
4370 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
4371 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
4372 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
4373 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
4374 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
4375 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
4376
4377 // Load the first 32 bytes of LHS and RHS data.
4378 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
4379 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
4380 "ld1 {v2.8b}, [%[rhs_ptr]]\n"
4381 "add %[rhs_ptr], %[rhs_ptr], #32\n"
4382
4383 // Clear accumulators.
4384 RUY_MAKE_ZERO(v16)
4385 RUY_MAKE_ZERO(v17)
4386
4387 // w1 is the number of levels of depth that we have already loaded
4388 // LHS and RHS data for. Corresponding to the initial ld1 instructions
4389 // above, this is currently 4.
4390 "mov w1, #4\n"
4391
4392 // Perform the first few multiply-adds on the data that we have already
4393 // loaded.
4394 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4395
4396 // Main loop of the whole GEMM, over rows and columns of the
4397 // destination matrix.
4398 "1:\n"
4399
4400         // Kernel inner loop (over depth). Unlike the kernel above, this one
4401         // has no partially-unrolled (RUY_OPT_MAX_STREAMING) variant.
4402
4403 // Reminder - w1 is how many levels of depth we have already loaded
4404 // data for, w12 is the total depth.
4405 "cmp w1, w12\n"
4406 "beq 79f\n"
4407
4408 "2:\n"
4409
4410 // Because of the data that we have already loaded, we can start the
4411 // loop body right away with some multiply-adds.
4412 // Each iteration of this loop advances by 4 levels of depth.
4413 "add w1, w1, #4\n"
4414 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
4415 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
4416 // Loop termination condition.
4417 "cmp w1, w12\n"
4418 "ld1 {v2.8b}, [%[rhs_ptr]]\n"
4419 "add %[rhs_ptr], %[rhs_ptr], #32\n"
4420 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4421 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
4422
4423 "blt 2b\n"
4424
4425 "79:\n"
4426 // End of the inner loop on depth. Now perform the remaining
4427 // multiply-adds of the last 4 levels of depth, for which the LHS
4428 // and RHS data is already loaded.
4429
4430 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
4431
4432         // End of accumulation. The registers v16 -- v17 contain the final
4433         // int32 accumulator values of the current 8x1 destination block.
4434         // We now have to compute the final 8-bit values from these int32
4435         // accumulators, and advance to the next 8x1 block. We intertwine
4436 // these two aspects whenever possible for optimal pipelining, both
4437 // at the data flow level (prefetch data for next block as early as
4438 // possible) and instruction pipelining level (some of the next-block
4439 // work can dual-issue with some of the final work on the current
4440 // block).
4441
4442 // Logic to advance to the next block in preparation for the next
4443 // iteration of the main loop. For now, we only want to compute
4444 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
4445 // not yet ready to update the values of row and col, as we still need
4446 // the current values for the rest of the work on the current block.
4447
4448 "cmp %w[row], w7\n" // Have we finished the last row?
4449 "bge 4f\n" // If finished last row, go to 4
4450 // Not finished last row: then advance to next row.
4451 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
4452 "b 5f\n"
4453 "4:\n" // Finished last row...
4454 "mov %[lhs_col_ptr], x5\n" // Go back to first row
4455 // Now we need to advance to the next column. If we already
4456 // finished the last column, then in principle we are done, however
4457 // we can't just return here, as we need to allow the end work of the
4458 // current block to complete. The good news is that at this point it
4459 // doesn't matter what data we load for the next column, since
4460 // we will exit from the main loop below before actually storing
4461 // anything computed from that data.
4462 "cmp %w[col], w8\n" // Have we finished the last column?
4463 "bge 5f\n" // If yes, just carry on without updating the column pointer.
4464 // Not finished last column: then advance to next column.
4465 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
4466 "5:\n"
4467
4468 // Set the LHS and RHS data pointers to the start of the columns just
4469 // computed.
4470 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
4471 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
4472
4473 // Load some parameters needed for the end work on current block.
4474 RUY_MAKE_ZERO(v8)
4475 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
4476 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
4477 "ins v13.h[4], w4\n" // dst_zero_point
4478 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
4479 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
4480 "dup v9.4s, w3\n" // create prod_zp_depth_vec
4481 "add x5, x4, %x[row], lsl #2\n"
4482 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
4483 "csel x4, x4, x5, eq\n"
4484
4485 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
4486 "add x5, x1, %x[row], lsl #2\n"
4487
4488 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
4489 "csel x1, x1, x5, eq\n"
4490
4491 // Load 8 bias values.
4492 "ld1 {v14.4s}, [x1], #16\n"
4493 "ld1 {v15.4s}, [x1]\n"
4494
4495 // Now that we know what LHS and RHS data the next iteration of the
4496         // main loop will need to load, we start loading the first 32 bytes of
4497         // LHS into v0 -- v1 and the first 8 bytes of RHS into v2, as we don't
4498         // need those registers anymore in the rest of the work on the current block.
4499 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
4500 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
4501 "ld1 {v2.8b}, [%[rhs_ptr]]\n"
4502 "add %[rhs_ptr], %[rhs_ptr], #32\n"
4503
4504         // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
4505 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
4506 "add v14.4s, v14.4s, v9.4s\n"
4507 "add v15.4s, v15.4s, v9.4s\n"
4508
4509 // Perform the bias-addition (per the above, we have just folded into
4510 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
4511 "add v16.4s, v16.4s, v14.4s\n"
4512 "add v17.4s, v17.4s, v15.4s\n"
4513
4514 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
4515 "beq 401f\n"
4516 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
4517 "add x3, x3, %x[col], lsl #2\n"
4518 "ld1 {v14.4s}, [x3], #16\n"
4519 "ld1 {v15.4s}, [x3]\n"
4520 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
4521 "dup v10.4s, w5\n" // create lhs_zero_point_vec
4522 // Subtract rhs_sums * lhs_zero_point, per
4523 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
4524 "mls v16.4s, v10.4s, v14.s[0]\n"
4525 "mls v17.4s, v10.4s, v14.s[0]\n"
4526 "401:\n"
4527
4528 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
4529 "beq 402f\n"
4530 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
4531 "add x2, x2, %x[row], lsl #2\n"
4532 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
4533 // Load 8 lhs_sums values.
4534 "ld1 {v11.4s}, [x2], #16\n"
4535 "ld1 {v12.4s}, [x2]\n"
4536 "ins v13.s[1], w5\n" // rhs_zero_point
4537 // Compute lhs_sums * rhs_zero_point.
4538 "mul v11.4s, v11.4s, v13.s[1]\n"
4539 "mul v12.4s, v12.4s, v13.s[1]\n"
4540 // Subtract lhs_sums * rhs_zero_point, per
4541 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
4542 "sub v16.4s, v16.4s, v11.4s\n"
4543 "sub v17.4s, v17.4s, v12.4s\n"
4544
4545 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
4546 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
4547
4548 "402:\n"
4549
4550 // At this point we have computed the final int32 values. Now we
4551 // start down-quantizing them to obtain the final 8bit values from them.
4552
4553 // As part of this down-quantization, our int32 values will be
4554 // multiplied by a multiplier that has a fixed-point component and an
4555 // exponent component.
4556
4557 // Load the exponent part of the multiplier.
4558 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
4559 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
4560 "add x5, x1, %x[row], lsl #2\n"
4561 "csel x1, x1, x5, eq\n"
4562
4563 "ldr q9, [x1]\n"
4564 "ldr q10, [x1, #16]\n"
4565
4566 "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
4567 "beq 403f\n"
4568 "smax v11.4s, v9.4s, v8.4s\n"
4569 "smax v12.4s, v10.4s, v8.4s\n"
4570 "sshl v16.4s, v16.4s, v11.4s\n"
4571 "sshl v17.4s, v17.4s, v12.4s\n"
4572 "403:\n"
4573
4574 "ldr q14, [x4]\n" // multiplier_fixedpoint
4575 "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
4576
4577 "smin v11.4s, v9.4s, v8.4s\n"
4578 "smin v12.4s, v10.4s, v8.4s\n"
4579
4580 // Apply the fixed-point part of the multiplier.
4581 "sqrdmulh v16.4s, v16.4s, v14.4s\n"
4582 "sqrdmulh v17.4s, v17.4s, v15.4s\n"
4583
4584 // We have some rounding division-by-power-of-two to do. This should
4585 // always use "round to nearest". We allow for some
4586 // freedom in how ties are broken, to strike a good compromise of
4587 // performance on given hardware vs. perfect agreement of results
4588 // across hardware.
4589 //
4590 // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
4591 // defined tie-breaks to help performance. On NEON, this means that we
4592 // can just use the NEON rounding instructions, such as srshl. They
4593 // happen to be breaking ties upward.
4594 //
4595 // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
4596 // break-ties-away-from zero, as described in Appendix B of
4597 // https://arxiv.org/pdf/1712.05877.pdf
4598 // When we wrote that, we thought that it would be less biased than
4599 // the NEON upward tie-breaks, and we had observed some improvement
4600 // on some model. However, it is only less biased for data centered
4601 // at zero, which was likely the case in that model, but is not always
4602 // the case. If we wanted something more consistently unbiased, we
4603 // should try breaking ties to nearest even.
4604 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
4605 // Fix up values to be right-shifted, so that the (round to nearest,
4606 // break ties upward) behavior of srshl applied to these fixed-up
4607 // values, produces the same result as the desired (round to nearest,
4608 // break ties away from zero) behavior on the original values.
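// A scalar sketch of this fixup (assuming, as here, that v11 holds the
// non-positive shift amounts):
//   int32 fixup = (x & shift) >> 31;   // -1 iff x < 0 and shift < 0
//   x = saturating_add(x, fixup);      // nudge negative values by -1
// after which srshl's round-half-upward equals round-half-away-from-zero.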
4609 "and v8.16b, v16.16b, v11.16b\n"
4610 "and v9.16b, v17.16b, v12.16b\n"
4611 "sshr v8.4s, v8.4s, #31\n"
4612 "sshr v9.4s, v9.4s, #31\n"
4613 "sqadd v16.4s, v16.4s, v8.4s\n"
4614 "sqadd v17.4s, v17.4s, v9.4s\n"
4615
4616 #endif
4617 // At this point we have reduced the problem of correctly implementing
4618 // rounding divide-by-power-of-two, to what the SRSHL instruction can
4619 // do.
4620 "srshl v16.4s, v16.4s, v11.4s\n"
4621 "srshl v17.4s, v17.4s, v12.4s\n"
4622
4623 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
4624 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
4625 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
4626 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
4627
4628 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
4629
4630 // Cast-and-saturate from int32 to int16
4631 "sqxtn v16.4h, v16.4s\n"
4632 "sqxtn2 v16.8h, v17.4s\n"
4633 // All data in v16 at this point.
4634
4635 // Add the destination zero point
4636 "dup v14.8h, v13.h[4]\n"
4637 "add v16.8h, v16.8h, v14.8h\n"
4638
4639 // Cast-and-saturate from int16 to uint8, leaving all data in the
4640 // lower half of v16.
4641 "sqxtun v16.8b, v16.8h\n"
4642
4643 // Load the clamp_min, clamp_max bounds
4644 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
4645 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
4646 "dup v14.16b, w2\n" // clamp_min
4647 "dup v15.16b, w3\n" // clamp_max
4648
4649 // Apply the clamp_min bound
4650 "umax v16.16b, v16.16b, v14.16b\n"
4651
4652 // Apply the clamp_max bound
4653 "umin v16.16b, v16.16b, v15.16b\n"
4654
4655 // Make it so that all of the final 8bit values are stored in the
4656 // first 64bits of 128bit NEON registers, so they can be stored
4657 // by 64bit st1 store instructions with byte alignment.
4658 "dup d20, v16.d[1]\n"
4659
4660 // Compute how much of the 8x1 block of destination 8bit values we
4661 // have computed fits in the destination matrix. Typically, all of
4662 // it fits, but when the destination matrix shape is not a multiple
4663 // of 8x1, there are some 8x1 blocks along the boundaries that do
4664 // not fit entirely.
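// In C terms, the next few instructions compute
//   w1 = rows_fit = min(8, dst_rows - row);
// and then select (x3 = address, x4 = stride) as either
// (dst_ptr, dst_stride) when the whole block fits, or
// (dst_tmp_buf, 8) when it does not.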
4665 "sub w1, %w[dst_rows], %w[row]\n"
4666 "sub w2, %w[dst_cols], %w[col]\n"
4667 "mov w3, #8\n"
4668 "cmp w1, #8\n"
4669 // Compute w1 = how many rows of the 8x1 block fit
4670 "csel w1, w1, w3, le\n"
4671 "cmp w2, #8\n"
4672
4673 // Test if w1==8, i.e. if all of the 8x1 block fits.
4674 "cmp w1, w3\n"
4675 // Yes, all of the 8x1 block fits, go to fast path.
4676 "beq 30f\n"
4677 // Not all of the 8x1 block fits.
4678 // Set (x3 address, x4 stride) to write to dst_tmp_buf
4679 "mov x3, %[dst_tmp_buf]\n"
4680 "mov x4, #8\n"
4681 "b 31f\n"
4682 "30:\n"
4683 // Yes, all of the 8x1 block fits.
4684 // Set (x3 address, x4 stride) to write directly to destination matrix.
4685 "mov x3, %[dst_ptr]\n"
4686 "mov x4, x11\n"
4687 "31:\n"
4688
4689 // Write our 8bit values to the destination
4690 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4691 "st1 {v16.8b}, [x3]\n"
4692 RUY_MAKE_ZERO(v16)
4693 RUY_MAKE_ZERO(v17)
4694
4695 // For the next block: perform the first few multiply-adds on the data
4696 // that we have already loaded.
4697 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4698
4699 // If all of the 8x1 block fits, we just finished writing it to the
4700 // destination, so we skip the next part.
4701 "beq 41f\n"
4702 // Not all of the 8x1 block fits in the destination matrix. We just
4703 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4704 // it to copy into the destination matrix the part that fits.
4705 "mov x3, %[dst_tmp_buf]\n"
4706 "mov x4, %[dst_ptr]\n"
4707 "mov w6, #0\n"
4708 "50:\n"
4709 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4710 "mov w5, #0\n"
4711 "51:\n"
4712 "ldrb w7, [x3, w5, uxtw]\n"
4713 "strb w7, [x4, w5, uxtw]\n"
4714 "add w5, w5, #1\n"
4715 "cmp w5, w1\n"
4716 "blt 51b\n"
4717 "41:\n"
4718 "add %[dst_ptr], %[dst_ptr], #8\n"
4719 // At this point we have completely finished writing values to the
4720 // destination matrix for the current block.
4721
4722 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
4723
4724 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
4725
4726 // Cast-and-saturate from int32 to int16
4727 "sqxtn v16.4h, v16.4s\n"
4728 "sqxtn2 v16.8h, v17.4s\n"
4729
4730
4731 // Add the destination zero point
4732 "dup v14.8h, v13.h[4]\n"
4733 "add v16.8h, v16.8h, v14.8h\n"
4734
4735 // Cast-and-saturate from int16 to int8
4736 "sqxtn v16.8b, v16.8h\n"
4737
4738 // Load the clamp_min, clamp_max bounds
4739 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
4740 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
4741 "dup v14.16b, w2\n" // clamp_min
4742 "dup v15.16b, w3\n" // clamp_max
4743
4744 // Apply the clamp_min bound
4745 "smax v16.16b, v16.16b, v14.16b\n"
4746
4747 // Apply the clamp_max bound
4748 "smin v16.16b, v16.16b, v15.16b\n"
4749
4750 // Make it so that all of the final 8bit values are stored in the
4751 // first 64bits of 128bit NEON registers, so they can be stored
4752 // by 64bit st1 store instructions with byte alignment.
4753 "dup d20, v16.d[1]\n"
4754
4755 // Compute how much of the 8x1 block of destination 8bit values we
4756 // have computed fits in the destination matrix. Typically, all of
4757 // it fits, but when the destination matrix shape is not a multiple
4758 // of 8x1, there are some 8x1 blocks along the boundaries that do
4759 // not fit entirely.
4760 "sub w1, %w[dst_rows], %w[row]\n"
4761 "sub w2, %w[dst_cols], %w[col]\n"
4762 "mov w3, #8\n"
4763 "cmp w1, #8\n"
4764 // Compute w1 = how many rows of the 8x1 block fit
4765 "csel w1, w1, w3, le\n"
4766 "cmp w2, #8\n"
4767
4768 // Test if w1==8, i.e. if all of the 8x1 block fits.
4769 "cmp w1, w3\n"
4770 // Yes, all of the 8x1 block fits, go to fast path.
4771 "beq 130f\n"
4772 // Not all of the 8x1 block fits.
4773 // Set (x3 address, x4 stride) to write to dst_tmp_buf
4774 "mov x3, %[dst_tmp_buf]\n"
4775 "mov x4, #8\n"
4776 "b 131f\n"
4777 "130:\n"
4778 // Yes, all of the 8x1 block fits.
4779 // Set (x3 address, x4 stride) to write directly to destination matrix.
4780 "mov x3, %[dst_ptr]\n"
4781 "mov x4, x11\n"
4782 "131:\n"
4783
4784 // Write our 8bit values to the destination
4785 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4786 "st1 {v16.8b}, [x3]\n"
4787 RUY_MAKE_ZERO(v16)
4788 RUY_MAKE_ZERO(v17)
4789
4790 // For the next block: perform the first few multiply-adds on the data
4791 // that we have already loaded.
4792 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4793
4794 // If all of the 8x1 block fits, we just finished writing it to the
4795 // destination, so we skip the next part.
4796 "beq 141f\n"
4797 // Not all of the 8x1 block fits in the destination matrix. We just
4798 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4799 // it to copy into the destination matrix the part that fits.
4800 "mov x3, %[dst_tmp_buf]\n"
4801 "mov x4, %[dst_ptr]\n"
4802 "mov w6, #0\n"
4803 "150:\n"
4804 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4805 "mov w5, #0\n"
4806 "151:\n"
4807 "ldrb w7, [x3, w5, uxtw]\n"
4808 "strb w7, [x4, w5, uxtw]\n"
4809 "add w5, w5, #1\n"
4810 "cmp w5, w1\n"
4811 "blt 151b\n"
4812 "141:\n"
4813 "add %[dst_ptr], %[dst_ptr], #8\n"
4814 // At this point we have completely finished writing values to the
4815 // destination matrix for the current block.
4816
4817 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
4818
4819 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
4820
4821 // Add the destination zero point
4822 "dup v14.8h, v13.h[4]\n"
4823 "saddw v16.4s, v16.4s, v14.4h\n"
4824 "saddw v17.4s, v17.4s, v14.4h\n"
4825
4826 // Cast-and-saturate from int32 to int16
4827 "sqxtn v16.4h, v16.4s\n"
4828 "sqxtn2 v16.8h, v17.4s\n"
4829
4830 // Load the clamp_min, clamp_max bounds
4831 "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
4832 "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
4833 "dup v14.8h, w2\n" // clamp_min
4834 "dup v15.8h, w3\n" // clamp_max
4835
4836 // Apply the clamp_min bound
4837 "smax v16.8h, v16.8h, v14.8h\n"
4838 // Apply the clamp_max bound
4839 "smin v16.8h, v16.8h, v15.8h\n"
4840
4841 // Compute how much of the 8x1 block of destination 16bit values we
4842 // have computed fits in the destination matrix. Typically, all of
4843 // it fits, but when the destination matrix shape is not a multiple
4844 // of 8x1, there are some 8x1 blocks along the boundaries that do
4845 // not fit entirely.
4846 "sub w1, %w[dst_rows], %w[row]\n"
4847 "sub w2, %w[dst_cols], %w[col]\n"
4848 "mov w3, #8\n"
4849 "cmp w1, #8\n"
4850 // Compute w1 = how many rows of the 8x1 block fit
4851 "csel w1, w1, w3, le\n"
4852 "cmp w2, #8\n"
4853
4854 // Test if w1==8, i.e. if all of the 8x1 block fits.
4855 "cmp w1, w3\n"
4856 // Yes, all of the 8x1 block fits, go to fast path.
4857 "beq 230f\n"
4858 // Not all of the 8x1 block fits.
4859 // Set (x3 address, x4 stride) to write to dst_tmp_buf
4860 "mov x3, %[dst_tmp_buf]\n"
4861 "mov x4, #16\n"
4862 "b 231f\n"
4863 "230:\n"
4864 // Yes, all of the 8x1 block fits.
4865 // Set (x3 address, x4 stride) to write directly to destination matrix.
4866 "mov x3, %[dst_ptr]\n"
4867 "mov x4, x11\n"
4868 "231:\n"
4869
4870 // Write our 16bit values to the destination
4871 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4872 "st1 {v16.8h}, [x3]\n"
4873 RUY_MAKE_ZERO(v16)
4874 RUY_MAKE_ZERO(v17)
4875
4876 // For the next block: perform the first few multiply-adds on the data
4877 // that we have already loaded.
4878 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4879
4880 // If all of the 8x1 block fits, we just finished writing it to the
4881 // destination, so we skip the next part.
4882 "beq 241f\n"
4883 // Not all of the 8x1 block fits in the destination matrix. We just
4884 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4885 // it to copy into the destination matrix the part that fits.
4886 "mov x3, %[dst_tmp_buf]\n"
4887 "mov x4, %[dst_ptr]\n"
4888 "mov w6, #0\n"
4889 "250:\n"
4890 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4891 "mov w5, #0\n"
4892 "251:\n"
4893 "ldrsh w7, [x3, x5, lsl #1]\n"
4894 "strh w7, [x4, x5, lsl #1]\n"
4895 "add w5, w5, #1\n"
4896 "cmp w5, w1\n"
4897 "blt 251b\n"
4898 "241:\n"
4899 "add %[dst_ptr], %[dst_ptr], #16\n"
4900 // At this point we have completely finished writing values to the
4901 // destination matrix for the current block.
4902
4903 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
4904
4905 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
4906
4907 // Since the store type is the same as the accum type, no need for
4908 // downcast. There's also no need for clamp by min/max.
4909
4910 // Compute how much of the 8x1 block of destination 32bit values we
4911 // have computed fits in the destination matrix. Typically, all of
4912 // it fits, but when the destination matrix shape is not a multiple
4913 // of 8x1, there are some 8x1 blocks along the boundaries that do
4914 // not fit entirely.
4915 "sub w1, %w[dst_rows], %w[row]\n"
4916 "sub w2, %w[dst_cols], %w[col]\n"
4917 "mov w3, #8\n"
4918 "cmp w1, #8\n"
4919 // Compute w1 = how many rows of the 8x1 block fit
4920 "csel w1, w1, w3, le\n"
4921 "cmp w2, #8\n"
4922 // Compute w2 = how many cols fit, clamped to 8
4923 "csel w2, w2, w3, le\n"
4924
4925 // Test if w1==8, i.e. if all of the 8x1 block fits.
4926 "cmp w1, w3\n"
4927 // Yes, all of the 8x1 block fits, go to fast path.
4928 "beq 330f\n"
4929 // Not all of the 8x1 block fits.
4930 // Set (x3 address, x4 stride) to write to dst_tmp_buf
4931 "mov x3, %[dst_tmp_buf]\n"
4932 "mov x4, #16\n"
4933
4934 // Write our 32bit values to the destination described by
4935 // (x3 address, x4 stride).
4936 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4937 "st1 {v16.4s}, [x3], x4\n"
4938 RUY_MAKE_ZERO(v16)
4939 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4940 "st1 {v17.4s}, [x3], x4\n"
4941 RUY_MAKE_ZERO(v17)
4942
4943 "b 331f\n"
4944
4945 "330:\n"
4946 // Yes, all of the 8x1 block fits.
4947 // Set (x3 address, x4 stride) to write directly to destination matrix.
4948 "mov x4, %[dst_ptr]\n"
4949 "mov x3, x4\n"
4950
4951 // Write our 32bit values to the destination described by
4952 // (x3 address, x4 stride).
4953 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4954 "st1 {v16.4s, v17.4s}, [x3], #32\n"
4955 RUY_MAKE_ZERO(v16)
4956 RUY_MAKE_ZERO(v17)
4957
4958 "331:\n"
4959
4960 // For the next block: perform the first few multiply-adds on the data
4961 // that we have already loaded.
4962 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4963
4964 // If all of the 8x1 block fits, we just finished writing it to the
4965 // destination, so we skip the next part.
4966 "beq 341f\n"
4967
4968 // Not all of the 8x1 block fits in the destination matrix. We just
4969 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4970 // it to copy into the destination matrix the part that fits.
4971 "mov x3, %[dst_tmp_buf]\n"
4972 "mov x4, %[dst_ptr]\n"
4973 "mov w6, #0\n"
4974 "350:\n"
4975 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4976 "mov w5, #0\n"
4977 "351:\n"
4978 "ldr w7, [x3, x5, lsl #2]\n"
4979 "str w7, [x4, x5, lsl #2]\n"
4980 "add w5, w5, #1\n"
4981 "cmp w5, w1\n"
4982 "blt 351b\n"
4983 "341:\n"
4984 "add %[dst_ptr], %[dst_ptr], #32\n"
4985 // At this point we have completely finished writing values to the
4986 // destination matrix for the current block.
4987
4988 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
4989
4990 // Reload some params --- we had used x5 -- x7 for a few other things
4991 // since the last time we had loaded them.
4992 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
4993 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
4994 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
4995
4996 // Move to the next block of the destination matrix, for the next iter
4997 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
4998 // been updated earlier.
4999 // Have we reached the end row?
5000 "cmp %w[row], w7\n"
5001 "beq 20f\n" // yes, end row.
5002 // Not end row. Move to the next row.
5003 "add %w[row], %w[row], #8\n"
5004 "b 21f\n"
5005 "20:\n"
5006 // Was already at end row.
5007 "mov %w[row], w6\n" // Move back to first row.
5008 "add %w[col], %w[col], #8\n" // Move to the next column.
5009 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
5010 "mov %[dst_ptr], %[dst_col_ptr]\n"
5011 "21:\n"
5012
5013 // Main loop exit condition: have we hit the end column?
5014 "cmp %w[col], w8\n"
5015
5016 // w1 is the number of levels of depth that we have already loaded
5017 // LHS and RHS data for. Corresponding to the initial ld1 instructions
5018 // above, this is currently 4.
5019 "mov w1, #4\n"
5020
5021 "ble 1b\n"
5022
5023 // clang-format on
5024
5025 : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
5026 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
5027 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
5028 : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
5029 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
5030 [dst_type_id] "r"(params.dst_type_id)
5031 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
5032 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
5033 "v13", "v14", "v15", "v16", "v17");
5034 }
5035
5036 // Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order
5037 // CPUs. Specifically here, the relevant in-order CPU is the ARM
5038 // Cortex-A55r1, since it is 64-bit and supports dotprod.
5039 //
5040 // While this kernel does not have a direct equivalent in gemmlowp, it was
5041 // developed based on insights that David Mansell at ARM shared with their
5042 // contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful
5043 // comments. Specifically, see this comment about tuning for Cortex-A55r1:
5044 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
5045 void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) {
5046 profiler::ScopeLabel label(
5047 "Kernel (kNeonDotprod, optimized for in-order cores)");
5048
5049 CheckOffsetsInKernelParams8bit(params);
5050
5051 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
5052 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
5053 const std::int8_t* lhs_ptr = lhs_col_ptr;
5054 const std::int8_t* rhs_ptr = rhs_col_ptr;
5055 void* dst_col_ptr = params.dst_base_ptr;
5056 void* dst_ptr = dst_col_ptr;
5057 int row = params.start_row;
5058 int col = params.start_col;
5059
5060 // The asm kernel below has the following NEON register allocation:
5061 //
5062 // v16 -- v31 are int32 accumulators.
5063 // During accumulation, v0 -- v3 are used to load int8 data from LHS and
5064 // RHS.
5065 //
5066 // int8 RHS 4x8 block
5067 // /-----------------------------------------\
5068 // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
5069 // | ... ... |
5070 // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
5071 // \-----------------------------------------/
5072 // int8 LHS 8x4 block
5073 // /---------------------\ /-----------------------------------------\
5074 // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
5075 // | ... ... | | ... ... |
5076 // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
5077 // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
5078 // | ... ... | | ... ... |
5079 // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
5080 // \---------------------/ \-----------------------------------------/
5081 // int32 accumulators 8x8 block
5082 //
5083 // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
5084 // we did not observe a benefit of such partial unrolling on in-order CPUs.
5085 //
5086 // v4 -- v7 are unused, and v8 -- v15 are used to load parameters for
5087 // the post-accumulation part of the kernel.
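// A note on the ".word 0x4f82e010"-style lines below: these are raw
// encodings of SDOT (by element), e.g. sdot v16.4s, v0.16b, v2.4b[0],
// presumably spelled out so that the code assembles even when the
// assembler is not known to support the dotprod extension. For each
// 32-bit lane i of the destination, sdot accumulates the dot product of
// the i-th group of 4 signed bytes of the first source with the selected
// group of 4 bytes of the second source.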
5088 asm volatile(
5089 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
5090
5091 // clang-format off
5092
5093 // Load some parameters into registers.
5094 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
5095 RUY_MAKE_ZERO(v16)
5096 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
5097 RUY_MAKE_ZERO(v17)
5098 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
5099 RUY_MAKE_ZERO(v18)
5100 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
5101 RUY_MAKE_ZERO(v19)
5102 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
5103 RUY_MAKE_ZERO(v20)
5104 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
5105 RUY_MAKE_ZERO(v21)
5106 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
5107 RUY_MAKE_ZERO(v22)
5108 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
5109
5110 // Load the first 32 bytes of LHS and RHS data.
5111 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
5112 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
5113 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
5114 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
5115
5116 // Clear accumulators.
5117 RUY_MAKE_ZERO(v23)
5118 RUY_MAKE_ZERO(v24)
5119 RUY_MAKE_ZERO(v25)
5120 RUY_MAKE_ZERO(v26)
5121 RUY_MAKE_ZERO(v27)
5122 // Perform the first few multiply-adds on the data that we have already
5123 // loaded.
5124 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
5125 RUY_MAKE_ZERO(v28)
5126 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
5127 RUY_MAKE_ZERO(v29)
5128 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
5129 RUY_MAKE_ZERO(v30)
5130 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
5131 RUY_MAKE_ZERO(v31)
5132
5133
5134 "1:\n"
5135
5136 "add x5, %[lhs_ptr], x12, lsl #3\n"
5137 "sub x5, x5, #32\n"
5138 "cmp %[lhs_ptr], x5\n"
5139
5140 "beq 79f\n"
5141
5142 // Main accumulation loop
5143 "2:\n"
5144 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
5145 "ldr x1, [%[lhs_ptr], #8]\n"
5146 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
5147 "ldr x3, [%[rhs_ptr], #8]\n"
5148 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
5149 "ldr x4, [%[rhs_ptr], #24]\n"
5150 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
5151 "ldr d0, [%[lhs_ptr], #0]\n"
5152 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
5153 "ins v0.d[1], x1\n"
5154 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
5155 "ldr x2, [%[lhs_ptr], #24]\n"
5156 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
5157 "add %[lhs_ptr], %[lhs_ptr], #32\n"
5158 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
5159 "ldr d2, [%[rhs_ptr], #0]\n"
5160 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
5161 "ins v2.d[1], x3\n"
5162 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
5163 "cmp %[lhs_ptr], x5\n"
5164 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
5165 "add %[rhs_ptr], %[rhs_ptr], #32\n"
5166 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
5167 "ldr d3, [%[rhs_ptr], #-16]\n"
5168 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
5169 "ldr d1, [%[lhs_ptr], #-16]\n"
5170 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
5171 "ins v3.d[1], x4\n"
5172 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
5173 "ins v1.d[1], x2\n"
5174 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
5175 "blt 2b\n"
5176
5177 // Last accumulation steps, nothing left to load.
5178 "79:\n"
5179 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
5180 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
5181 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
5182 "cmp %w[row], w7\n" // Have we finished the last row?
5183 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
5184 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
5185 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
5186 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
5187 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
5188 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
5189 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
5190 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
5191 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
5192 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
5193
5194 // End of accumulation. The registers v16 -- v31 contain the final
5195 // int32 accumulator values of the current 8x8 destination block.
5196 // We now have to compute the final 8-bit values from these int32
5197 // accumulators, and advance to the next 8x8 block. We intertwine
5198 // these two aspects whenever possible for optimal pipelining, both
5199 // at the data flow level (prefetch data for next block as early as
5200 // possible) and instruction pipelining level (some of the next-block
5201 // work can dual-issue with some of the final work on the current
5202 // block).
5203
5204 // Logic to advance to the next block in preparation for the next
5205 // iteration of the main loop. For now, we only want to compute
5206 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
5207 // not yet ready to update the values of row and col, as we still need
5208 // the current values for the rest of the work on the current block.
5209
5210 "bge 4f\n" // If finished last row, go to 4
5211 // Not finished last row: then advance to next row.
5212 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
5213 "b 5f\n"
5214 "4:\n" // Finished last row...
5215 "mov %[lhs_col_ptr], x5\n" // Go back to first row
5216 // Now we need to advance to the next column. If we already
5217 // finished the last column, then in principle we are done, however
5218 // we can't just return here, as we need to allow the end work of the
5219 // current block to complete. The good news is that at this point it
5220 // doesn't matter what data we load for the next column, since
5221 // we will exit from the main loop below before actually storing
5222 // anything computed from that data.
5223 "cmp %w[col], w8\n" // Have we finished the last column?
5224 "bge 5f\n" // If yes, just carry on without updating the column pointer.
5225 // Not finished last column: then advance to next column.
5226 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
5227 "5:\n"
5228
5229 // Set the LHS and RHS data pointers to the start of the columns just
5230 // computed.
5231 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
5232 // Load some parameters needed for the end work on current block.
5233 RUY_MAKE_ZERO(v8)
5234 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
5235 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
5236 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
5237 "ins v13.h[4], w4\n" // dst_zero_point
5238 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
5239 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
5240 "dup v9.4s, w3\n" // create prod_zp_depth_vec
5241 "add x5, x4, %x[row], lsl #2\n"
5242 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
5243 "csel x4, x4, x5, eq\n"
5244
5245 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
5246 "add x5, x1, %x[row], lsl #2\n"
5247 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
5248 "csel x1, x1, x5, eq\n"
5249
5250 // Load 8 bias values.
5251 "ld1 {v14.2s}, [x1], #8\n"
5252 "ldr x5, [x1], #8\n"
5253 "ins v14.d[1], x5\n"
5254 "ld1 {v15.2s}, [x1], #8\n"
5255 "ldr x5, [x1], #8\n"
5256 "ins v15.d[1], x5\n"
5257
5258 // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
5259 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
5260 "add v14.4s, v14.4s, v9.4s\n"
5261 "add v15.4s, v15.4s, v9.4s\n"
5262 // Perform the bias-addition (per the above, we have just folded into
5263 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
5264 "add v16.4s, v16.4s, v14.4s\n"
5265 "add v17.4s, v17.4s, v15.4s\n"
5266 "add v18.4s, v18.4s, v14.4s\n"
5267 "add v19.4s, v19.4s, v15.4s\n"
5268 "add v20.4s, v20.4s, v14.4s\n"
5269 "add v21.4s, v21.4s, v15.4s\n"
5270 "add v22.4s, v22.4s, v14.4s\n"
5271 "add v23.4s, v23.4s, v15.4s\n"
5272 "add v24.4s, v24.4s, v14.4s\n"
5273 "add v25.4s, v25.4s, v15.4s\n"
5274 "add v26.4s, v26.4s, v14.4s\n"
5275 "add v27.4s, v27.4s, v15.4s\n"
5276 "add v28.4s, v28.4s, v14.4s\n"
5277 "add v29.4s, v29.4s, v15.4s\n"
5278 "add v30.4s, v30.4s, v14.4s\n"
5279 "add v31.4s, v31.4s, v15.4s\n"
5280
5281 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
5282 "beq 401f\n"
5283 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
5284 "add x3, x3, %x[col], lsl #2\n"
5285 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
5286 "dup v10.4s, w5\n" // create lhs_zero_point_vec
5287 // Load 8 rhs_sums values.
5288 "ld1 {v14.2s}, [x3], #8\n"
5289 "ldr x7, [x3], #8\n"
5290 "ld1 {v15.2s}, [x3], #8\n"
5291 "ins v14.d[1], x7\n"
5292 "ldr x7, [x3], #8\n"
5293 "ins v15.d[1], x7\n"
5294 // Subtract rhs_sums * lhs_zero_point, per
5295 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
5296 "mls v16.4s, v10.4s, v14.s[0]\n"
5297 "mls v17.4s, v10.4s, v14.s[0]\n"
5298 "mls v18.4s, v10.4s, v14.s[1]\n"
5299 "mls v19.4s, v10.4s, v14.s[1]\n"
5300 "mls v20.4s, v10.4s, v14.s[2]\n"
5301 "mls v21.4s, v10.4s, v14.s[2]\n"
5302 "mls v22.4s, v10.4s, v14.s[3]\n"
5303 "mls v23.4s, v10.4s, v14.s[3]\n"
5304 "mls v24.4s, v10.4s, v15.s[0]\n"
5305 "mls v25.4s, v10.4s, v15.s[0]\n"
5306 "mls v26.4s, v10.4s, v15.s[1]\n"
5307 "mls v27.4s, v10.4s, v15.s[1]\n"
5308 "mls v28.4s, v10.4s, v15.s[2]\n"
5309 "mls v29.4s, v10.4s, v15.s[2]\n"
5310 "mls v30.4s, v10.4s, v15.s[3]\n"
5311 "mls v31.4s, v10.4s, v15.s[3]\n"
5312 "401:\n"
5313
5314 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
5315 "beq 402f\n"
5316 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
5317 "add x2, x2, %x[row], lsl #2\n"
5318 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
5319 "ins v13.s[1], w5\n" // rhs_zero_point
5320 // Load 8 lhs_sums values.
5321 "ld1 {v11.2s}, [x2], #8\n"
5322 "ldr x6, [x2], #8\n"
5323 "ins v11.d[1], x6\n"
5324 "ld1 {v12.2s}, [x2], #8\n"
5325 "ldr x6, [x2], #8\n"
5326 "ins v12.d[1], x6\n"
5327 // Compute lhs_sums * rhs_zero_point.
5328 "mul v11.4s, v11.4s, v13.s[1]\n"
5329 "mul v12.4s, v12.4s, v13.s[1]\n"
5330 // Subtract lhs_sums * rhs_zero_point, per
5331 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
5332 "sub v16.4s, v16.4s, v11.4s\n"
5333 "sub v17.4s, v17.4s, v12.4s\n"
5334 "sub v18.4s, v18.4s, v11.4s\n"
5335 "sub v19.4s, v19.4s, v12.4s\n"
5336 "sub v20.4s, v20.4s, v11.4s\n"
5337 "sub v21.4s, v21.4s, v12.4s\n"
5338 "sub v22.4s, v22.4s, v11.4s\n"
5339 "sub v23.4s, v23.4s, v12.4s\n"
5340 "sub v24.4s, v24.4s, v11.4s\n"
5341 "sub v25.4s, v25.4s, v12.4s\n"
5342 "sub v26.4s, v26.4s, v11.4s\n"
5343 "sub v27.4s, v27.4s, v12.4s\n"
5344 "sub v28.4s, v28.4s, v11.4s\n"
5345 "sub v29.4s, v29.4s, v12.4s\n"
5346 "sub v30.4s, v30.4s, v11.4s\n"
5347 "sub v31.4s, v31.4s, v12.4s\n"
5348
5349 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
5350 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
5351
5352 "402:\n"
5353
5354 // At this point we have computed the final int32 values. Now we
5355 // start down-quantizing them to obtain the final 8bit values from them.
5356
5357 // As part of this down-quantization, our int32 values will be
5358 // multiplied by a multiplier that has a fixed-point component and an
5359 // exponent component.
5360
5361 // Load the exponent part of the multiplier.
5362 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
5363 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
5364 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
5365 "add x5, x1, %x[row], lsl #2\n"
5366 "csel x1, x1, x5, eq\n"
5367
5368 "ldr q9, [x1]\n"
5369 "ldr q10, [x1, #16]\n"
5370
5371 "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
5372 "beq 403f\n"
5373 "smax v11.4s, v9.4s, v8.4s\n"
5374 "smax v12.4s, v10.4s, v8.4s\n"
5375 "sshl v16.4s, v16.4s, v11.4s\n"
5376 "sshl v17.4s, v17.4s, v12.4s\n"
5377 "sshl v18.4s, v18.4s, v11.4s\n"
5378 "sshl v19.4s, v19.4s, v12.4s\n"
5379 "sshl v20.4s, v20.4s, v11.4s\n"
5380 "sshl v21.4s, v21.4s, v12.4s\n"
5381 "sshl v22.4s, v22.4s, v11.4s\n"
5382 "sshl v23.4s, v23.4s, v12.4s\n"
5383 "sshl v24.4s, v24.4s, v11.4s\n"
5384 "sshl v25.4s, v25.4s, v12.4s\n"
5385 "sshl v26.4s, v26.4s, v11.4s\n"
5386 "sshl v27.4s, v27.4s, v12.4s\n"
5387 "sshl v28.4s, v28.4s, v11.4s\n"
5388 "sshl v29.4s, v29.4s, v12.4s\n"
5389 "sshl v30.4s, v30.4s, v11.4s\n"
5390 "sshl v31.4s, v31.4s, v12.4s\n"
5391 "403:\n"
5392
5393 "ldr q14, [x4]\n" // multiplier_fixedpoint
5394 "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
5395
5396 "smin v11.4s, v9.4s, v8.4s\n"
5397 "smin v12.4s, v10.4s, v8.4s\n"
5398
5399 // Apply the fixed-point part of the multiplier.
5400 //
5401 // ... and, interleaved into that:
5402 // Now that we know what LHS and RHS data the next iteration of the
5403 // main loop will need to load, we start loading the first 32 bytes of
5404 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
5405 // in the rest of the work on the current block.
5406 "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
5407 "sqrdmulh v16.4s, v16.4s, v14.4s\n"
5408 "ldr x1, [%[lhs_ptr]], #8\n"
5409 "sqrdmulh v17.4s, v17.4s, v15.4s\n"
5410 "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
5411 "sqrdmulh v18.4s, v18.4s, v14.4s\n"
5412 "ldr x2, [%[lhs_ptr]], #8\n"
5413 "sqrdmulh v19.4s, v19.4s, v15.4s\n"
5414 "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
5415 "sqrdmulh v20.4s, v20.4s, v14.4s\n"
5416 "ldr x5, [%[rhs_ptr]], #8\n"
5417 "sqrdmulh v21.4s, v21.4s, v15.4s\n"
5418 "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
5419 "sqrdmulh v22.4s, v22.4s, v14.4s\n"
5420 "ldr x6, [%[rhs_ptr]], #8\n"
5421 "sqrdmulh v23.4s, v23.4s, v15.4s\n"
5422 "sqrdmulh v24.4s, v24.4s, v14.4s\n"
5423 "sqrdmulh v25.4s, v25.4s, v15.4s\n"
5424 "sqrdmulh v26.4s, v26.4s, v14.4s\n"
5425 "sqrdmulh v27.4s, v27.4s, v15.4s\n"
5426 "sqrdmulh v28.4s, v28.4s, v14.4s\n"
5427 "sqrdmulh v29.4s, v29.4s, v15.4s\n"
5428 "sqrdmulh v30.4s, v30.4s, v14.4s\n"
5429 "sqrdmulh v31.4s, v31.4s, v15.4s\n"
5430
5431 // We have some rounding division-by-power-of-two to do. This should
5432 // always use "round to nearest". We allow for some
5433 // freedom in how ties are broken, to strike a good compromise of
5434 // performance on given hardware vs. perfect agreement of results
5435 // across hardware.
5436 //
5437 // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
5438 // defined tie-breaks to help performance. On NEON, this means that we
5439 // can just use the NEON rounding instructions, such as srshl. They
5440 // happen to be breaking ties upward.
5441 //
5442 // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
5443 // break-ties-away-from zero, as described in Appendix B of
5444 // https://arxiv.org/pdf/1712.05877.pdf
5445 // When we wrote that, we thought that it would be less biased than
5446 // the NEON upward tie-breaks, and we had observed some improvement
5447 // on some model. However, it is only less biased for data centered
5448 // at zero, which was likely the case in that model, but is not always
5449 // the case. If we wanted something more consistently unbiased, we
5450 // should try breaking ties to nearest even.
5451 #if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
5452 // Fix up values to be right-shifted, so that the (round to nearest,
5453 // break ties upward) behavior of srshl applied to these fixed-up
5454 // values, produces the same result as the desired (round to nearest,
5455 // break ties away from zero) behavior on the original values.
5456 "and v8.16b, v16.16b, v11.16b\n"
5457 "and v9.16b, v17.16b, v12.16b\n"
5458 "and v14.16b, v18.16b, v11.16b\n"
5459 "and v15.16b, v19.16b, v12.16b\n"
5460 "sshr v8.4s, v8.4s, #31\n"
5461 "sshr v9.4s, v9.4s, #31\n"
5462 "sshr v14.4s, v14.4s, #31\n"
5463 "sshr v15.4s, v15.4s, #31\n"
5464 "sqadd v16.4s, v16.4s, v8.4s\n"
5465 "sqadd v17.4s, v17.4s, v9.4s\n"
5466 "sqadd v18.4s, v18.4s, v14.4s\n"
5467 "sqadd v19.4s, v19.4s, v15.4s\n"
5468 "and v8.16b, v20.16b, v11.16b\n"
5469 "and v9.16b, v21.16b, v12.16b\n"
5470 "and v14.16b, v22.16b, v11.16b\n"
5471 "and v15.16b, v23.16b, v12.16b\n"
5472 "sshr v8.4s, v8.4s, #31\n"
5473 "sshr v9.4s, v9.4s, #31\n"
5474 "sshr v14.4s, v14.4s, #31\n"
5475 "sshr v15.4s, v15.4s, #31\n"
5476 "sqadd v20.4s, v20.4s, v8.4s\n"
5477 "sqadd v21.4s, v21.4s, v9.4s\n"
5478 "sqadd v22.4s, v22.4s, v14.4s\n"
5479 "sqadd v23.4s, v23.4s, v15.4s\n"
5480 "and v8.16b, v24.16b, v11.16b\n"
5481 "and v9.16b, v25.16b, v12.16b\n"
5482 "and v14.16b, v26.16b, v11.16b\n"
5483 "and v15.16b, v27.16b, v12.16b\n"
5484 "sshr v8.4s, v8.4s, #31\n"
5485 "sshr v9.4s, v9.4s, #31\n"
5486 "sshr v14.4s, v14.4s, #31\n"
5487 "sshr v15.4s, v15.4s, #31\n"
5488 "sqadd v24.4s, v24.4s, v8.4s\n"
5489 "sqadd v25.4s, v25.4s, v9.4s\n"
5490 "sqadd v26.4s, v26.4s, v14.4s\n"
5491 "sqadd v27.4s, v27.4s, v15.4s\n"
5492 "and v8.16b, v28.16b, v11.16b\n"
5493 "and v9.16b, v29.16b, v12.16b\n"
5494 "and v14.16b, v30.16b, v11.16b\n"
5495 "and v15.16b, v31.16b, v12.16b\n"
5496 "sshr v8.4s, v8.4s, #31\n"
5497 "sshr v9.4s, v9.4s, #31\n"
5498 "sshr v14.4s, v14.4s, #31\n"
5499 "sshr v15.4s, v15.4s, #31\n"
5500 "sqadd v28.4s, v28.4s, v8.4s\n"
5501 "sqadd v29.4s, v29.4s, v9.4s\n"
5502 "sqadd v30.4s, v30.4s, v14.4s\n"
5503 "sqadd v31.4s, v31.4s, v15.4s\n"
5504 #endif
5505 // At this point we have reduced the problem of correctly implementing
5506 // rounding divide-by-power-of-two, to what the SRSHL instruction can
5507 // do.
5508 "srshl v16.4s, v16.4s, v11.4s\n"
5509 "srshl v17.4s, v17.4s, v12.4s\n"
5510 "srshl v18.4s, v18.4s, v11.4s\n"
5511 "srshl v19.4s, v19.4s, v12.4s\n"
5512 "srshl v20.4s, v20.4s, v11.4s\n"
5513 "srshl v21.4s, v21.4s, v12.4s\n"
5514 "srshl v22.4s, v22.4s, v11.4s\n"
5515 "srshl v23.4s, v23.4s, v12.4s\n"
5516 "srshl v24.4s, v24.4s, v11.4s\n"
5517 "srshl v25.4s, v25.4s, v12.4s\n"
5518 "srshl v26.4s, v26.4s, v11.4s\n"
5519 "srshl v27.4s, v27.4s, v12.4s\n"
5520 "ins v0.d[1], x1\n"
5521 "srshl v28.4s, v28.4s, v11.4s\n"
5522 "ins v1.d[1], x2\n"
5523 "srshl v29.4s, v29.4s, v12.4s\n"
5524 "ins v2.d[1], x5\n"
5525 "srshl v30.4s, v30.4s, v11.4s\n"
5526 "ins v3.d[1], x6\n"
5527 "srshl v31.4s, v31.4s, v12.4s\n"
5528
5529 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
5530 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
5531 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
5532 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
5533
5534 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
5535
5536 // Cast-and-saturate from int32 to int16
5537 "sqxtn v16.4h, v16.4s\n"
5538 "sqxtn2 v16.8h, v17.4s\n"
5539 "sqxtn v17.4h, v18.4s\n"
5540 "sqxtn2 v17.8h, v19.4s\n"
5541 "sqxtn v18.4h, v20.4s\n"
5542 "sqxtn2 v18.8h, v21.4s\n"
5543 "sqxtn v19.4h, v22.4s\n"
5544 "sqxtn2 v19.8h, v23.4s\n"
5545 "sqxtn v20.4h, v24.4s\n"
5546 "sqxtn2 v20.8h, v25.4s\n"
5547 "sqxtn v21.4h, v26.4s\n"
5548 "sqxtn2 v21.8h, v27.4s\n"
5549 "sqxtn v22.4h, v28.4s\n"
5550 "sqxtn2 v22.8h, v29.4s\n"
5551 "sqxtn v23.4h, v30.4s\n"
5552 "sqxtn2 v23.8h, v31.4s\n"
5553
5554 // Destination zero_point
5555 "dup v14.8h, v13.h[4]\n"
5556 // At this point, v24 -- v31 aren't used anymore for the current block,
5557 // so we can start clearing these accumulators for the next block
5558 // (next iteration of the main loop).
5559 RUY_MAKE_ZERO(v24)
5560 RUY_MAKE_ZERO(v25)
5561 RUY_MAKE_ZERO(v26)
5562 RUY_MAKE_ZERO(v27)
5563 RUY_MAKE_ZERO(v28)
5564 RUY_MAKE_ZERO(v29)
5565 RUY_MAKE_ZERO(v30)
5566 RUY_MAKE_ZERO(v31)
5567
5568 // Add the destination zero point
5569 "add v16.8h, v16.8h, v14.8h\n"
5570 "add v17.8h, v17.8h, v14.8h\n"
5571 "add v18.8h, v18.8h, v14.8h\n"
5572 "add v19.8h, v19.8h, v14.8h\n"
5573 "add v20.8h, v20.8h, v14.8h\n"
5574 "add v21.8h, v21.8h, v14.8h\n"
5575 "add v22.8h, v22.8h, v14.8h\n"
5576 "add v23.8h, v23.8h, v14.8h\n"
5577
5578 // Load the clamp_min, clamp_max bounds
5579 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
5580 // Cast-and-saturate from int16 to uint8
5581 "sqxtun v16.8b, v16.8h\n"
5582 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
5583 "sqxtun2 v16.16b, v17.8h\n"
5584 "sqxtun v17.8b, v18.8h\n"
5585 "sqxtun2 v17.16b, v19.8h\n"
5586 "sqxtun v18.8b, v20.8h\n"
5587 "sqxtun2 v18.16b, v21.8h\n"
5588 "sqxtun v19.8b, v22.8h\n"
5589 "sqxtun2 v19.16b, v23.8h\n"
5590
5591 "dup v14.16b, w2\n" // clamp_min
5592 "dup v15.16b, w3\n" // clamp_max
5593
5594 // Compute how much of the 8x8 block of destination 8bit values we
5595 // have computed fits in the destination matrix. Typically, all of
5596 // it fits, but when the destination matrix shape is not a multiple
5597 // of 8x8, there are some 8x8 blocks along the boundaries that do
5598 // not fit entirely.
5599 "sub w1, %w[dst_rows], %w[row]\n"
5600 // Apply the clamp_min bound
5601 "umax v16.16b, v16.16b, v14.16b\n"
5602 "sub w2, %w[dst_cols], %w[col]\n"
5603 "umax v17.16b, v17.16b, v14.16b\n"
5604 "mov w3, #8\n"
5605 "umax v18.16b, v18.16b, v14.16b\n"
5606 "cmp w1, #8\n"
5607 "umax v19.16b, v19.16b, v14.16b\n"
5608 // Compute w1 = how many rows of the 8x8 block fit
5609 "csel w1, w1, w3, le\n"
5610 // Apply the clamp_max bound
5611 "umin v16.16b, v16.16b, v15.16b\n"
5612 "cmp w2, #8\n"
5613 "umin v17.16b, v17.16b, v15.16b\n"
5614 // Compute w2 = how many cols of the 8x8 block fit
5615 "csel w2, w2, w3, le\n"
5616 "umin v18.16b, v18.16b, v15.16b\n"
5617 "umin v19.16b, v19.16b, v15.16b\n"
5618
5619 // Make it so that all of the final 8bit values are stored in the
5620 // first 64bits of 128bit NEON registers, so they can be stored
5621 // by 64bit st1 store instructions with byte alignment.
5622 "dup d20, v16.d[1]\n"
5623 "dup d21, v17.d[1]\n"
5624 "dup d22, v18.d[1]\n"
5625 "dup d23, v19.d[1]\n"
5626
5627 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
5628 "cmp w1, w3\n"
5629 "ccmp w2, w3, 0, eq\n"
5630 // Yes, all of the 8x8 block fits, go to fast path.
5631 "beq 30f\n"
5632 // Not all of the 8x8 block fits.
5633 // Set (x3 address, x4 stride) to write to dst_tmp_buf
5634 "mov x3, %[dst_tmp_buf]\n"
5635 "mov x4, #8\n"
5636 "b 31f\n"
5637 "30:\n"
5638 // Yes, all of the 8x8 block fits.
5639 // Set (x3 address, x4 stride) to write directly to destination matrix.
5640 "mov x3, %[dst_ptr]\n"
5641 "mov x4, x11\n"
5642 "31:\n"
5643
5644 // Write our 8bit values to the destination described by
5645 // (x3 address, x4 stride).
5646 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5647 "st1 {v16.8b}, [x3], x4\n"
5648 RUY_MAKE_ZERO(v16)
5649 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5650 "st1 {v20.8b}, [x3], x4\n"
5651 RUY_MAKE_ZERO(v20)
5652 // For the next block: perform the first few multiply-adds on the data
5653 // that we have already loaded.
5654 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
5655 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5656 "st1 {v17.8b}, [x3], x4\n"
5657 RUY_MAKE_ZERO(v17)
5658 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
5659 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5660 "st1 {v21.8b}, [x3], x4\n"
5661 RUY_MAKE_ZERO(v21)
5662 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5663 "st1 {v18.8b}, [x3], x4\n"
5664 RUY_MAKE_ZERO(v18)
5665 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5666 "st1 {v22.8b}, [x3], x4\n"
5667 RUY_MAKE_ZERO(v22)
5668 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
5669 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5670 "st1 {v19.8b}, [x3], x4\n"
5671 RUY_MAKE_ZERO(v19)
5672 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
5673 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5674 "st1 {v23.8b}, [x3], x4\n"
5675 RUY_MAKE_ZERO(v23)
5676
5677 // If all of the 8x8 block fits, we just finished writing it to the
5678 // destination, so we skip the next part.
5679 "beq 41f\n"
5680 // Not all of the 8x8 block fits in the destination matrix. We just
5681 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
5682 // it to copy into the destination matrix the part that fits.
5683 "mov x3, %[dst_tmp_buf]\n"
5684 "mov x4, %[dst_ptr]\n"
5685 "mov w6, #0\n"
5686 "50:\n"
5687 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
5688 "mov w5, #0\n"
5689 "51:\n"
5690 "ldrb w7, [x3, w5, uxtw]\n"
5691 "strb w7, [x4, w5, uxtw]\n"
5692 "add w5, w5, #1\n"
5693 "cmp w5, w1\n"
5694 "blt 51b\n"
5695 "add w6, w6, #1\n"
5696 "add x3, x3, #8\n"
5697 "add x4, x4, x11\n"
5698 "cmp w6, w2\n"
5699 "blt 50b\n"
5700 "41:\n"
5701 "add %[dst_ptr], %[dst_ptr], #8\n"
5702
5703 // At this point we have completely finished writing values to the
5704 // destination matrix for the current block.
5705
5706 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
5707
5708 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
5709
5710 // Cast-and-saturate from int32 to int16
5711 "sqxtn v16.4h, v16.4s\n"
5712 "sqxtn2 v16.8h, v17.4s\n"
5713 "sqxtn v17.4h, v18.4s\n"
5714 "sqxtn2 v17.8h, v19.4s\n"
5715 "sqxtn v18.4h, v20.4s\n"
5716 "sqxtn2 v18.8h, v21.4s\n"
5717 "sqxtn v19.4h, v22.4s\n"
5718 "sqxtn2 v19.8h, v23.4s\n"
5719 "sqxtn v20.4h, v24.4s\n"
5720 "sqxtn2 v20.8h, v25.4s\n"
5721 "sqxtn v21.4h, v26.4s\n"
5722 "sqxtn2 v21.8h, v27.4s\n"
5723 "sqxtn v22.4h, v28.4s\n"
5724 "sqxtn2 v22.8h, v29.4s\n"
5725 "sqxtn v23.4h, v30.4s\n"
5726 "sqxtn2 v23.8h, v31.4s\n"
5727
5728 // Destination zero_point
5729 "dup v14.8h, v13.h[4]\n"
5730 // At this point, v24 -- v31 aren't used anymore for the current block,
5731 // so we can start clearing these accumulators for the next block
5732 // (next iteration of the main loop).
5733 RUY_MAKE_ZERO(v24)
5734 RUY_MAKE_ZERO(v25)
5735 RUY_MAKE_ZERO(v26)
5736 RUY_MAKE_ZERO(v27)
5737 RUY_MAKE_ZERO(v28)
5738 RUY_MAKE_ZERO(v29)
5739 RUY_MAKE_ZERO(v30)
5740 RUY_MAKE_ZERO(v31)
5741
5742 // Add the destination zero point
5743 "add v16.8h, v16.8h, v14.8h\n"
5744 "add v17.8h, v17.8h, v14.8h\n"
5745 "add v18.8h, v18.8h, v14.8h\n"
5746 "add v19.8h, v19.8h, v14.8h\n"
5747 "add v20.8h, v20.8h, v14.8h\n"
5748 "add v21.8h, v21.8h, v14.8h\n"
5749 "add v22.8h, v22.8h, v14.8h\n"
5750 "add v23.8h, v23.8h, v14.8h\n"
5751
5752 // Load the clamp_min, clamp_max bounds
5753 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
5754 // Cast-and-saturate from int16 to int8
5755 "sqxtn v16.8b, v16.8h\n"
5756 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
5757 "sqxtn2 v16.16b, v17.8h\n"
5758 "sqxtn v17.8b, v18.8h\n"
5759 "sqxtn2 v17.16b, v19.8h\n"
5760 "sqxtn v18.8b, v20.8h\n"
5761 "sqxtn2 v18.16b, v21.8h\n"
5762 "sqxtn v19.8b, v22.8h\n"
5763 "sqxtn2 v19.16b, v23.8h\n"
5764
5765 "dup v14.16b, w2\n" // clamp_min
5766 "dup v15.16b, w3\n" // clamp_max
5767
5768 // Compute how much of the 8x8 block of destination 8bit values we
5769 // have computed fits in the destination matrix. Typically, all of
5770 // it fits, but when the destination matrix shape is not a multiple
5771 // of 8x8, there are some 8x8 blocks along the boundaries that do
5772 // not fit entirely.
5773 "sub w1, %w[dst_rows], %w[row]\n"
5774 // Apply the clamp_min bound
5775 "smax v16.16b, v16.16b, v14.16b\n"
5776 "sub w2, %w[dst_cols], %w[col]\n"
5777 "smax v17.16b, v17.16b, v14.16b\n"
5778 "mov w3, #8\n"
5779 "smax v18.16b, v18.16b, v14.16b\n"
5780 "cmp w1, #8\n"
5781 "smax v19.16b, v19.16b, v14.16b\n"
5782 // Compute w1 = how many rows of the 8x8 block fit
5783 "csel w1, w1, w3, le\n"
5784 // Apply the clamp_max bound
5785 "smin v16.16b, v16.16b, v15.16b\n"
5786 "cmp w2, #8\n"
5787 "smin v17.16b, v17.16b, v15.16b\n"
5788 // Compute w2 = how many cols of the 8x8 block fit
5789 "csel w2, w2, w3, le\n"
5790 "smin v18.16b, v18.16b, v15.16b\n"
5791 "smin v19.16b, v19.16b, v15.16b\n"
5792
5793 // Make it so that all of the final 8bit values are stored in the
5794 // first 64bits of 128bit NEON registers, so they can be stored
5795 // by 64bit st1 store instructions with byte alignment.
5796 "dup d20, v16.d[1]\n"
5797 "dup d21, v17.d[1]\n"
5798 "dup d22, v18.d[1]\n"
5799 "dup d23, v19.d[1]\n"
5800
5801 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
5802 "cmp w1, w3\n"
5803 "ccmp w2, w3, 0, eq\n"
5804 // Yes, all of the 8x8 block fits, go to fast path.
5805 "beq 130f\n"
5806 // Not all of the 8x8 block fits.
5807 // Set (x3 address, x4 stride) to write to dst_tmp_buf
5808 "mov x3, %[dst_tmp_buf]\n"
5809 "mov x4, #8\n"
5810 "b 131f\n"
5811 "130:\n"
5812 // Yes, all of the 8x8 block fits.
5813 // Set (x3 address, x4 stride) to write directly to destination matrix.
5814 "mov x3, %[dst_ptr]\n"
5815 "mov x4, x11\n"
5816 "131:\n"
5817
5818 // Write our 8bit values to the destination described by
5819 // (x3 address, x4 stride).
5820 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5821 "st1 {v16.8b}, [x3], x4\n"
5822 RUY_MAKE_ZERO(v16)
5823 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5824 "st1 {v20.8b}, [x3], x4\n"
5825 RUY_MAKE_ZERO(v20)
5826 // For the next block: perform the first few multiply-adds on the data
5827 // that we have already loaded.
5828 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
5829 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5830 "st1 {v17.8b}, [x3], x4\n"
5831 RUY_MAKE_ZERO(v17)
5832 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
5833 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5834 "st1 {v21.8b}, [x3], x4\n"
5835 RUY_MAKE_ZERO(v21)
5836 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5837 "st1 {v18.8b}, [x3], x4\n"
5838 RUY_MAKE_ZERO(v18)
5839 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5840 "st1 {v22.8b}, [x3], x4\n"
5841 RUY_MAKE_ZERO(v22)
5842 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
5843 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5844 "st1 {v19.8b}, [x3], x4\n"
5845 RUY_MAKE_ZERO(v19)
5846 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
5847 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5848 "st1 {v23.8b}, [x3], x4\n"
5849 RUY_MAKE_ZERO(v23)
5850
5851 // If all of the 8x8 block fits, we just finished writing it to the
5852 // destination, so we skip the next part.
5853 "beq 141f\n"
5854 // Not all of the 8x8 block fits in the destination matrix. We just
5855 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
5856 // it to copy into the destination matrix the part that fits.
5857 "mov x3, %[dst_tmp_buf]\n"
5858 "mov x4, %[dst_ptr]\n"
5859 "mov w6, #0\n"
5860 "150:\n"
5861 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
5862 "mov w5, #0\n"
5863 "151:\n"
5864 "ldrb w7, [x3, w5, uxtw]\n"
5865 "strb w7, [x4, w5, uxtw]\n"
5866 "add w5, w5, #1\n"
5867 "cmp w5, w1\n"
5868 "blt 151b\n"
5869 "add w6, w6, #1\n"
5870 "add x3, x3, #8\n"
5871 "add x4, x4, x11\n"
5872 "cmp w6, w2\n"
5873 "blt 150b\n"
5874 "141:\n"
5875 "add %[dst_ptr], %[dst_ptr], #8\n"
5876
5877 // At this point we have completely finished writing values to the
5878 // destination matrix for the current block.
5879
5880 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
5881
5882 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
5883
5884 // Add the destination zero point
5885 "dup v14.8h, v13.h[4]\n"
5886 "saddw v16.4s, v16.4s, v14.4h\n"
5887 "saddw v17.4s, v17.4s, v14.4h\n"
5888 "saddw v18.4s, v18.4s, v14.4h\n"
5889 "saddw v19.4s, v19.4s, v14.4h\n"
5890 "saddw v20.4s, v20.4s, v14.4h\n"
5891 "saddw v21.4s, v21.4s, v14.4h\n"
5892 "saddw v22.4s, v22.4s, v14.4h\n"
5893 "saddw v23.4s, v23.4s, v14.4h\n"
5894 "saddw v24.4s, v24.4s, v14.4h\n"
5895 "saddw v25.4s, v25.4s, v14.4h\n"
5896 "saddw v26.4s, v26.4s, v14.4h\n"
5897 "saddw v27.4s, v27.4s, v14.4h\n"
5898 "saddw v28.4s, v28.4s, v14.4h\n"
5899 "saddw v29.4s, v29.4s, v14.4h\n"
5900 "saddw v30.4s, v30.4s, v14.4h\n"
5901 "saddw v31.4s, v31.4s, v14.4h\n"
5902
5903 // Cast-and-saturate from int32 to int16
5904 "sqxtn v16.4h, v16.4s\n"
5905 "sqxtn2 v16.8h, v17.4s\n"
5906 "sqxtn v17.4h, v18.4s\n"
5907 "sqxtn2 v17.8h, v19.4s\n"
5908 "sqxtn v18.4h, v20.4s\n"
5909 "sqxtn2 v18.8h, v21.4s\n"
5910 "sqxtn v19.4h, v22.4s\n"
5911 "sqxtn2 v19.8h, v23.4s\n"
5912 "sqxtn v20.4h, v24.4s\n"
5913 "sqxtn2 v20.8h, v25.4s\n"
5914 "sqxtn v21.4h, v26.4s\n"
5915 "sqxtn2 v21.8h, v27.4s\n"
5916 "sqxtn v22.4h, v28.4s\n"
5917 "sqxtn2 v22.8h, v29.4s\n"
5918 "sqxtn v23.4h, v30.4s\n"
5919 "sqxtn2 v23.8h, v31.4s\n"
5920
5921 // At this point, v24 -- v31 aren't used anymore for the current block,
5922 // so we can start clearing these accumulators for the next block
5923 // (next iteration of the main loop).
5924 RUY_MAKE_ZERO(v24)
5925 RUY_MAKE_ZERO(v25)
5926 RUY_MAKE_ZERO(v26)
5927 RUY_MAKE_ZERO(v27)
5928 RUY_MAKE_ZERO(v28)
5929 RUY_MAKE_ZERO(v29)
5930 RUY_MAKE_ZERO(v30)
5931 RUY_MAKE_ZERO(v31)
5932
5933 // Load the clamp_min, clamp_max bounds
5934 "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
5935 "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
5936 "dup v14.8h, w2\n" // clamp_min
5937 "dup v15.8h, w3\n" // clamp_max
5938
5939 // Apply the clamp_min bound
5940 "smax v16.8h, v16.8h, v14.8h\n"
5941 "smax v17.8h, v17.8h, v14.8h\n"
5942 "smax v18.8h, v18.8h, v14.8h\n"
5943 "smax v19.8h, v19.8h, v14.8h\n"
5944 "smax v20.8h, v20.8h, v14.8h\n"
5945 "smax v21.8h, v21.8h, v14.8h\n"
5946 "smax v22.8h, v22.8h, v14.8h\n"
5947 "smax v23.8h, v23.8h, v14.8h\n"
5948 // Apply the clamp_max bound
5949 "smin v16.8h, v16.8h, v15.8h\n"
5950 "smin v17.8h, v17.8h, v15.8h\n"
5951 "smin v18.8h, v18.8h, v15.8h\n"
5952 "smin v19.8h, v19.8h, v15.8h\n"
5953 "smin v20.8h, v20.8h, v15.8h\n"
5954 "smin v21.8h, v21.8h, v15.8h\n"
5955 "smin v22.8h, v22.8h, v15.8h\n"
5956 "smin v23.8h, v23.8h, v15.8h\n"
5957
5958 // Compute how much of the 8x8 block of destination 16bit values we
5959 // have computed fits in the destination matrix. Typically, all of
5960 // it fits, but when the destination matrix shape is not a multiple
5961 // of 8x8, there are some 8x8 blocks along the boundaries that do
5962 // not fit entirely.
5963 "sub w1, %w[dst_rows], %w[row]\n"
5964 "sub w2, %w[dst_cols], %w[col]\n"
5965 "mov w3, #8\n"
5966 "cmp w1, #8\n"
5967 // Compute w1 = how many rows of the 8x8 block fit
5968 "csel w1, w1, w3, le\n"
5969 "cmp w2, #8\n"
5970 // Compute w2 = how many cols of the 8x8 block fit
5971 "csel w2, w2, w3, le\n"
5972
5973 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
5974 "cmp w1, w3\n"
5975 "ccmp w2, w3, 0, eq\n"
5976 // Yes, all of the 8x8 block fits, go to fast path.
5977 "beq 230f\n"
5978 // Not all of the 8x8 block fits.
5979 // Set (x3 address, x4 stride) to write to dst_tmp_buf
5980 "mov x3, %[dst_tmp_buf]\n"
5981 "mov x4, #16\n"
5982 "b 231f\n"
5983 "230:\n"
5984 // Yes, all of the 8x8 block fits.
5985 // Set (x3 address, x4 stride) to write directly to destination matrix.
5986 "mov x3, %[dst_ptr]\n"
5987 "mov x4, x11\n"
5988 "231:\n"
5989
5990        // Write our 16bit values to the destination described by
5991 // (x3 address, x4 stride).
5992 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5993 "st1 {v16.8h}, [x3], x4\n"
5994 RUY_MAKE_ZERO(v16)
5995 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5996 "st1 {v17.8h}, [x3], x4\n"
5997 RUY_MAKE_ZERO(v17)
5998 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5999 "st1 {v18.8h}, [x3], x4\n"
6000 RUY_MAKE_ZERO(v18)
6001 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6002 "st1 {v19.8h}, [x3], x4\n"
6003 RUY_MAKE_ZERO(v19)
6004 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6005 "st1 {v20.8h}, [x3], x4\n"
6006 RUY_MAKE_ZERO(v20)
6007 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6008 "st1 {v21.8h}, [x3], x4\n"
6009 RUY_MAKE_ZERO(v21)
6010 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6011 "st1 {v22.8h}, [x3], x4\n"
6012 RUY_MAKE_ZERO(v22)
6013 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6014 "st1 {v23.8h}, [x3], x4\n"
6015 RUY_MAKE_ZERO(v23)
6016
6017 // For the next block: perform the first few multiply-adds on the data
6018 // that we have already loaded.
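// These .word directives are the raw encodings of the sdot instructions
// shown in the disassembly comments; emitting them this way lets the file
// assemble even with toolchains that do not know the dotprod extension.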
6019 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
6020 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
6021 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
6022 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
6023
6024 // If all of the 8x8 block fits, we just finished writing it to the
6025 // destination, so we skip the next part.
6026 "beq 241f\n"
6027 // Not all of the 8x8 block fits in the destination matrix. We just
6028 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
6029 // it to copy into the destination matrix the part that fits.
6030 "mov x3, %[dst_tmp_buf]\n"
6031 "mov x4, %[dst_ptr]\n"
6032 "mov w6, #0\n"
6033 "250:\n"
6034 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6035 "mov w5, #0\n"
6036 "251:\n"
6037 "ldrsh w7, [x3, x5, lsl #1]\n"
6038 "strh w7, [x4, x5, lsl #1]\n"
6039 "add w5, w5, #1\n"
6040 "cmp w5, w1\n"
6041 "blt 251b\n"
6042 "add w6, w6, #1\n"
6043 "add x3, x3, #16\n"
6044 "add x4, x4, x11\n"
6045 "cmp w6, w2\n"
6046 "blt 250b\n"
6047 "241:\n"
6048 "add %[dst_ptr], %[dst_ptr], #16\n"
6049 // At this point we have completely finished writing values to the
6050 // destination matrix for the current block.
6051
6052 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
6053
6054 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
6055
6056 "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
6057 "ldr x1, [%[lhs_ptr]], #8\n"
6058 "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
6059 "ldr x2, [%[lhs_ptr]], #8\n"
6060 "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
6061 "ldr x5, [%[rhs_ptr]], #8\n"
6062 "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
6063 "ldr x6, [%[rhs_ptr]], #8\n"
6064 "ins v0.d[1], x1\n"
6065 "ins v1.d[1], x2\n"
6066 "ins v2.d[1], x5\n"
6067 "ins v3.d[1], x6\n"
6068
6069 // Since the store type is the same as the accum type, no need for
6070 // downcast. There's also no need for clamp by min/max.
6071
6072        // Compute how much of the 8x8 block of destination 32bit values that
6073        // we have computed fits in the destination matrix. Typically, all of
6074        // it fits, but when the destination matrix shape is not a multiple
6075        // of 8x8, there are some 8x8 blocks along the boundaries that do
6076        // not fit entirely.
6077 "sub w1, %w[dst_rows], %w[row]\n"
6078 "sub w2, %w[dst_cols], %w[col]\n"
6079 "mov w3, #8\n"
6080 "cmp w1, #8\n"
6081 // Compute w1 = how many rows of the 8x8 block fit
6082 "csel w1, w1, w3, le\n"
6083 "cmp w2, #8\n"
6084        // Compute w2 = how many cols of the 8x8 block fit
6085 "csel w2, w2, w3, le\n"
6086
6087 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
6088 "cmp w1, w3\n"
6089 "ccmp w2, w3, 0, eq\n"
6090 // Yes, all of the 8x8 block fits, go to fast path.
6091 "beq 330f\n"
6092 // Not all of the 8x8 block fits.
6093 // Write to dst_tmp_buf
6094 "mov x3, %[dst_tmp_buf]\n"
6095 "st1 {v16.4s}, [x3], #16\n"
6096 RUY_MAKE_ZERO(v16)
6097 "st1 {v17.4s}, [x3], #16\n"
6098 RUY_MAKE_ZERO(v17)
6099 "st1 {v18.4s}, [x3], #16\n"
6100 RUY_MAKE_ZERO(v18)
6101 "st1 {v19.4s}, [x3], #16\n"
6102 RUY_MAKE_ZERO(v19)
6103 "st1 {v20.4s}, [x3], #16\n"
6104 RUY_MAKE_ZERO(v20)
6105 "st1 {v21.4s}, [x3], #16\n"
6106 RUY_MAKE_ZERO(v21)
6107 "st1 {v22.4s}, [x3], #16\n"
6108 RUY_MAKE_ZERO(v22)
6109 "st1 {v23.4s}, [x3], #16\n"
6110 RUY_MAKE_ZERO(v23)
6111 "st1 {v24.4s}, [x3], #16\n"
6112 RUY_MAKE_ZERO(v24)
6113 "st1 {v25.4s}, [x3], #16\n"
6114 RUY_MAKE_ZERO(v25)
6115 "st1 {v26.4s}, [x3], #16\n"
6116 RUY_MAKE_ZERO(v26)
6117 "st1 {v27.4s}, [x3], #16\n"
6118 RUY_MAKE_ZERO(v27)
6119 "st1 {v28.4s}, [x3], #16\n"
6120 RUY_MAKE_ZERO(v28)
6121 "st1 {v29.4s}, [x3], #16\n"
6122 RUY_MAKE_ZERO(v29)
6123 "st1 {v30.4s}, [x3], #16\n"
6124 RUY_MAKE_ZERO(v30)
6125 "st1 {v31.4s}, [x3], #16\n"
6126 RUY_MAKE_ZERO(v31)
6127
6128 "b 331f\n"
6129
6130 "330:\n"
6131 // Yes, all of the 8x8 block fits.
6132 "mov x4, %[dst_ptr]\n"
6133 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6134 "st1 {v16.4s, v17.4s}, [x4], x11\n"
6135 RUY_MAKE_ZERO(v16)
6136 RUY_MAKE_ZERO(v17)
6137 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6138 "st1 {v18.4s, v19.4s}, [x4], x11\n"
6139 RUY_MAKE_ZERO(v18)
6140 RUY_MAKE_ZERO(v19)
6141 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6142 "st1 {v20.4s, v21.4s}, [x4], x11\n"
6143 RUY_MAKE_ZERO(v20)
6144 RUY_MAKE_ZERO(v21)
6145 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6146 "st1 {v22.4s, v23.4s}, [x4], x11\n"
6147 RUY_MAKE_ZERO(v22)
6148 RUY_MAKE_ZERO(v23)
6149 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6150 "st1 {v24.4s, v25.4s}, [x4], x11\n"
6151 RUY_MAKE_ZERO(v24)
6152 RUY_MAKE_ZERO(v25)
6153 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6154 "st1 {v26.4s, v27.4s}, [x4], x11\n"
6155 RUY_MAKE_ZERO(v26)
6156 RUY_MAKE_ZERO(v27)
6157 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6158 "st1 {v28.4s, v29.4s}, [x4], x11\n"
6159 RUY_MAKE_ZERO(v28)
6160 RUY_MAKE_ZERO(v29)
6161 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6162 "st1 {v30.4s, v31.4s}, [x4], x11\n"
6163 RUY_MAKE_ZERO(v30)
6164 RUY_MAKE_ZERO(v31)
6165
6166 "331:\n"
6167
6168 // For the next block: perform the first few multiply-adds on the data
6169 // that we have already loaded.
6170 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
6171 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
6172 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
6173 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
6174
6175 // If all of the 8x8 block fits, we just finished writing it to the
6176 // destination, so we skip the next part.
6177 "beq 341f\n"
6178
6179 // Not all of the 8x8 block fits in the destination matrix. We just
6180 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
6181 // it to copy into the destination matrix the part that fits.
6182 "mov x3, %[dst_tmp_buf]\n"
6183 "mov x4, %[dst_ptr]\n"
6184 "mov w6, #0\n"
6185 "350:\n"
6186 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6187 "mov w5, #0\n"
6188 "351:\n"
6189 "ldr w7, [x3, x5, lsl #2]\n"
6190 "str w7, [x4, x5, lsl #2]\n"
6191 "add w5, w5, #1\n"
6192 "cmp w5, w1\n"
6193 "blt 351b\n"
6194 "add w6, w6, #1\n"
6195 "add x3, x3, #32\n"
6196 "add x4, x4, x11\n"
6197 "cmp w6, w2\n"
6198 "blt 350b\n"
6199 "341:\n"
6200 "add %[dst_ptr], %[dst_ptr], #32\n"
6201 // At this point we have completely finished writing values to the
6202 // destination matrix for the current block.
6203
6204 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
6205
6206 // Reload some params --- we had used x5 -- x7 for a few other things
6207 // since the last time we had loaded them.
6208 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
6209 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
6210
6211 // Move to the next block of the destination matrix, for the next iter
6212 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
6213 // been updated earlier.
6214 // Have we reached the end row?
6215 "cmp %w[row], w7\n"
6216 "beq 20f\n" // yes, end row.
6217 // Not end row. Move to the next row.
6218 "add %w[row], %w[row], #8\n"
6219 "b 21f\n"
6220 "20:\n"
6221 // Was already at end row.
6222 "mov %w[row], w6\n" // Move back to first row.
6223 "add %w[col], %w[col], #8\n" // Move to the next column.
6224 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
6225 "mov %[dst_ptr], %[dst_col_ptr]\n"
6226 "21:\n"
6227
6228 // Main loop exit condition: have we hit the end column?
6229 "cmp %w[col], w8\n"
6230 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
6231 "ble 1b\n"
6232
6233 // clang-format on
6234
6235 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
6236 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
6237 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
6238         : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
6239 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
6240 [dst_type_id] "r"(params.dst_type_id)
6241 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
6242 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
6243 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
6244 "v26", "v27", "v28", "v29", "v30", "v31");
6245 }
6246 #undef RUY_OFFSET_BIAS
6247 #undef RUY_OFFSET_LHS_SUMS
6248 #undef RUY_OFFSET_RHS_SUMS
6249 #undef RUY_OFFSET_LHS_BASE_PTR
6250 #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
6251 #undef RUY_OFFSET_MULTIPLIER_EXPONENT
6252 #undef RUY_OFFSET_RHS_BASE_PTR
6253 #undef RUY_OFFSET_DST_BASE_PTR
6254 #undef RUY_OFFSET_LHS_ZERO_POINT
6255 #undef RUY_OFFSET_RHS_ZERO_POINT
6256 #undef RUY_OFFSET_DST_ZERO_POINT
6257 #undef RUY_OFFSET_PROD_ZP_DEPTH
6258 #undef RUY_OFFSET_START_ROW
6259 #undef RUY_OFFSET_START_COL
6260 #undef RUY_OFFSET_LAST_ROW
6261 #undef RUY_OFFSET_LAST_COL
6262 #undef RUY_OFFSET_DST_ROWS
6263 #undef RUY_OFFSET_DST_COLS
6264 #undef RUY_OFFSET_LHS_STRIDE
6265 #undef RUY_OFFSET_RHS_STRIDE
6266 #undef RUY_OFFSET_DST_STRIDE
6267 #undef RUY_OFFSET_DEPTH
6268 #undef RUY_OFFSET_CLAMP_MIN
6269 #undef RUY_OFFSET_CLAMP_MAX
6270 #undef RUY_OFFSET_FLAGS
6271
6272 #define RUY_OFFSET_LHS_BASE_PTR 0
6273 #define RUY_OFFSET_RHS_BASE_PTR 8
6274 #define RUY_OFFSET_DST_BASE_PTR 16
6275 #define RUY_OFFSET_BIAS 24
6276 #define RUY_OFFSET_START_ROW 32
6277 #define RUY_OFFSET_START_COL 36
6278 #define RUY_OFFSET_LAST_ROW 40
6279 #define RUY_OFFSET_LAST_COL 44
6280 #define RUY_OFFSET_LHS_STRIDE 56
6281 #define RUY_OFFSET_RHS_STRIDE 60
6282 #define RUY_OFFSET_DST_STRIDE 64
6283 #define RUY_OFFSET_DEPTH 68
6284 #define RUY_OFFSET_CLAMP_MIN 72
6285 #define RUY_OFFSET_CLAMP_MAX 76
6286 #define RUY_OFFSET_FLAGS 80
6287
6288 template <typename Params>
6289 void CheckOffsetsInKernelParamsFloat(const Params&) {
6290 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
6291 static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
6292 static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
6293 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
6294 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
6295 static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
6296 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
6297 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
6298 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
6299 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
6300 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
6301 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
6302 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
6303 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
6304 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
6305 }
6306
6307 // Just a plain float kernel; good enough for out-of-order cores.
6308 // The closest to it in the gemmlowp collection would be
6309 // NEON_64bit_GEMM_Float32_WithScalar,
6310 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925
6311 //
6312 // Besides ruy-ification, the main nuance here is that we stick to an 8x8
6313 // width instead of the wider 12x8 that the register space permits and that
6314 // the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now,
6315 // and we don't have evidence that going beyond 8x8 is needed.
6316 void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) {
6317 CheckOffsetsInKernelParamsFloat(params);
6318 profiler::ScopeLabel label(
6319 "Kernel (kNeon, optimized for out-of-order cores)");
6320
6321 const float* lhs_col_ptr = params.lhs_base_ptr;
6322 const float* rhs_col_ptr = params.rhs_base_ptr;
6323 const float* lhs_ptr = lhs_col_ptr;
6324 const float* rhs_ptr = rhs_col_ptr;
6325 float* dst_col_ptr = params.dst_base_ptr;
6326 float* dst_ptr = dst_col_ptr;
6327 int row = params.start_row;
6328 int col = params.start_col;
6329
6330 // The asm kernel below has the following NEON register allocation:
6331 //
6332 // v16 -- v31 are accumulators.
6333 // During accumulation, v0 -- v15 are used to load data from LHS and RHS.
6334 // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and
6335 // v3 are used to load a 1x8 block of RHS, like this:
6336 //
6337 // RHS 1x8 block
6338 // /-----------------------------------------\
6339 // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
6340 // \-----------------------------------------/
6341 // LHS 8x1 block
6342 // /---------------------\ /-----------------------------------------\
6343 // | v0.s[0] | |v16.s[0] ... v30.s[0]|
6344 // | ... | | ... ... |
6345 // | v0.s[3] | |v16.s[3] ... v30.s[3]|
6346 // | v1.s[0] | |v17.s[0] ... v31.s[0]|
6347 // | ... | | ... ... |
6348 // | v1.s[3] | |v17.s[3] ... v31.s[3]|
6349 // \---------------------/ \-----------------------------------------/
6350 // accumulators 8x8 block
6351 //
6352 // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
6353 // is repeated 4 times, using 4x more registers for LHS and RHS, so that
6354 // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
6355 //
6356 // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
6357   // unused, and v8 -- v15 are used to load parameters for the
6358 // post-accumulation part of the kernel.
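  // In scalar terms, each level of depth in the asm below performs, as a
  // sketch (with acc[c][r] standing for the v16 -- v31 accumulators, lhs
  // for the current 8x1 LHS column and rhs for the current 1x8 RHS row):
  //
  //   for (int c = 0; c < 8; ++c) {
  //     for (int r = 0; r < 8; ++r) {
  //       acc[c][r] += lhs[r] * rhs[c];
  //     }
  //   }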
6359 asm volatile(
6360 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
6361
6362 // clang-format off
6363
6364 // Load some parameters into registers.
6365 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
6366 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
6367 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
6368 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
6369 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
6370 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
6371 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
6372 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
6373
6374 // Load the first 32 bytes of LHS and RHS data.
6375 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6376 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6377 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
6378 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6379
6380 // Clear accumulators.
6381 RUY_MAKE_ZERO(v16)
6382 RUY_MAKE_ZERO(v17)
6383 RUY_MAKE_ZERO(v18)
6384 RUY_MAKE_ZERO(v19)
6385 RUY_MAKE_ZERO(v20)
6386 RUY_MAKE_ZERO(v21)
6387 RUY_MAKE_ZERO(v22)
6388 RUY_MAKE_ZERO(v23)
6389 RUY_MAKE_ZERO(v24)
6390 RUY_MAKE_ZERO(v25)
6391 RUY_MAKE_ZERO(v26)
6392 RUY_MAKE_ZERO(v27)
6393 RUY_MAKE_ZERO(v28)
6394 RUY_MAKE_ZERO(v29)
6395 RUY_MAKE_ZERO(v30)
6396 RUY_MAKE_ZERO(v31)
6397
6398 // w1 is the number of levels of depth that we have already loaded
6399 // LHS and RHS data for. Corresponding to the initial ld1 instructions
6400 // above, this is currently 1.
6401 "mov w1, #1\n"
6402
6403 // Main loop of the whole GEMM, over rows and columns of the
6404 // destination matrix.
6405 "1:\n"
6406
6407 "fmla v16.4s, v0.4s, v2.s[0]\n"
6408 "fmla v18.4s, v0.4s, v2.s[1]\n"
6409 "fmla v20.4s, v0.4s, v2.s[2]\n"
6410 "fmla v22.4s, v0.4s, v2.s[3]\n"
6411
6412 #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
6413 "cmp w12, #8\n"
6414 "blt 78f\n"
6415 "and w2, w12, #-4\n"
6416
6417 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
6418 "ld1 {v5.4s}, [%[lhs_ptr]], #16\n"
6419 "ld1 {v6.4s}, [%[rhs_ptr]], #16\n"
6420 "ld1 {v7.4s}, [%[rhs_ptr]], #16\n"
6421
6422 "ld1 {v8.4s}, [%[lhs_ptr]], #16\n"
6423 "ld1 {v9.4s}, [%[lhs_ptr]], #16\n"
6424 "ld1 {v10.4s}, [%[rhs_ptr]], #16\n"
6425 "ld1 {v11.4s}, [%[rhs_ptr]], #16\n"
6426
6427 "ld1 {v12.4s}, [%[lhs_ptr]], #16\n"
6428 "ld1 {v13.4s}, [%[lhs_ptr]], #16\n"
6429 "ld1 {v14.4s}, [%[rhs_ptr]], #16\n"
6430 "ld1 {v15.4s}, [%[rhs_ptr]], #16\n"
6431 "mov w1, #4\n"
6432
6433 "80:\n"
6434
6435 "add %[lhs_ptr], %[lhs_ptr], #128\n"
6436 "add %[rhs_ptr], %[rhs_ptr], #128\n"
6437
6438 "fmla v24.4s, v0.4s, v3.s[0]\n"
6439 "fmla v26.4s, v0.4s, v3.s[1]\n"
6440 "fmla v28.4s, v0.4s, v3.s[2]\n"
6441 "fmla v30.4s, v0.4s, v3.s[3]\n"
6442 "ldr q0, [%[lhs_ptr], #-128]\n"
6443 "fmla v25.4s, v1.4s, v3.s[0]\n"
6444 "fmla v27.4s, v1.4s, v3.s[1]\n"
6445 "fmla v29.4s, v1.4s, v3.s[2]\n"
6446 "fmla v31.4s, v1.4s, v3.s[3]\n"
6447 "ldr q3, [%[rhs_ptr], #-112]\n"
6448 "fmla v17.4s, v1.4s, v2.s[0]\n"
6449 "fmla v19.4s, v1.4s, v2.s[1]\n"
6450 "fmla v21.4s, v1.4s, v2.s[2]\n"
6451 "fmla v23.4s, v1.4s, v2.s[3]\n"
6452 "ldr q1, [%[lhs_ptr], #-112]\n"
6453 "fmla v16.4s, v4.4s, v6.s[0]\n"
6454 "fmla v18.4s, v4.4s, v6.s[1]\n"
6455 "ldr q2, [%[rhs_ptr], #-128]\n"
6456 "fmla v20.4s, v4.4s, v6.s[2]\n"
6457 "fmla v22.4s, v4.4s, v6.s[3]\n"
6458
6459 "fmla v24.4s, v4.4s, v7.s[0]\n"
6460 "fmla v26.4s, v4.4s, v7.s[1]\n"
6461 "fmla v28.4s, v4.4s, v7.s[2]\n"
6462 "fmla v30.4s, v4.4s, v7.s[3]\n"
6463 "ldr q4, [%[lhs_ptr], #-96]\n"
6464 "fmla v25.4s, v5.4s, v7.s[0]\n"
6465 "fmla v27.4s, v5.4s, v7.s[1]\n"
6466 "fmla v29.4s, v5.4s, v7.s[2]\n"
6467 "fmla v31.4s, v5.4s, v7.s[3]\n"
6468 "ldr q7, [%[rhs_ptr], #-80]\n"
6469 "fmla v17.4s, v5.4s, v6.s[0]\n"
6470 "fmla v19.4s, v5.4s, v6.s[1]\n"
6471 "fmla v21.4s, v5.4s, v6.s[2]\n"
6472 "fmla v23.4s, v5.4s, v6.s[3]\n"
6473 "ldr q5, [%[lhs_ptr], #-80]\n"
6474 "fmla v16.4s, v8.4s, v10.s[0]\n"
6475 "fmla v18.4s, v8.4s, v10.s[1]\n"
6476 "ldr q6, [%[rhs_ptr], #-96]\n"
6477 "fmla v20.4s, v8.4s, v10.s[2]\n"
6478 "fmla v22.4s, v8.4s, v10.s[3]\n"
6479
6480 "fmla v24.4s, v8.4s, v11.s[0]\n"
6481 "fmla v26.4s, v8.4s, v11.s[1]\n"
6482 "fmla v28.4s, v8.4s, v11.s[2]\n"
6483 "fmla v30.4s, v8.4s, v11.s[3]\n"
6484 "ldr q8, [%[lhs_ptr], #-64]\n"
6485 "fmla v25.4s, v9.4s, v11.s[0]\n"
6486 "fmla v27.4s, v9.4s, v11.s[1]\n"
6487 "fmla v29.4s, v9.4s, v11.s[2]\n"
6488 "fmla v31.4s, v9.4s, v11.s[3]\n"
6489 "ldr q11, [%[rhs_ptr], #-48]\n"
6490 "fmla v17.4s, v9.4s, v10.s[0]\n"
6491 "fmla v19.4s, v9.4s, v10.s[1]\n"
6492 "fmla v21.4s, v9.4s, v10.s[2]\n"
6493 "fmla v23.4s, v9.4s, v10.s[3]\n"
6494 "ldr q9, [%[lhs_ptr], #-48]\n"
6495 "fmla v16.4s, v12.4s, v14.s[0]\n"
6496 "fmla v18.4s, v12.4s, v14.s[1]\n"
6497 "ldr q10, [%[rhs_ptr], #-64]\n"
6498 "fmla v20.4s, v12.4s, v14.s[2]\n"
6499 "fmla v22.4s, v12.4s, v14.s[3]\n"
6500
6501 "fmla v24.4s, v12.4s, v15.s[0]\n"
6502 "fmla v26.4s, v12.4s, v15.s[1]\n"
6503 "fmla v28.4s, v12.4s, v15.s[2]\n"
6504 "fmla v30.4s, v12.4s, v15.s[3]\n"
6505 "ldr q12, [%[lhs_ptr], #-32]\n"
6506 "fmla v25.4s, v13.4s, v15.s[0]\n"
6507 "fmla v27.4s, v13.4s, v15.s[1]\n"
6508 "fmla v29.4s, v13.4s, v15.s[2]\n"
6509 "fmla v31.4s, v13.4s, v15.s[3]\n"
6510 "ldr q15, [%[rhs_ptr], #-16]\n"
6511 "fmla v17.4s, v13.4s, v14.s[0]\n"
6512 "fmla v19.4s, v13.4s, v14.s[1]\n"
6513 "fmla v21.4s, v13.4s, v14.s[2]\n"
6514 "fmla v23.4s, v13.4s, v14.s[3]\n"
6515 "ldr q13, [%[lhs_ptr], #-16]\n"
6516 "fmla v16.4s, v0.4s, v2.s[0]\n"
6517 "fmla v18.4s, v0.4s, v2.s[1]\n"
6518 "ldr q14, [%[rhs_ptr], #-32]\n"
6519 "fmla v20.4s, v0.4s, v2.s[2]\n"
6520 "fmla v22.4s, v0.4s, v2.s[3]\n"
6521
6522 "add w1, w1, #4\n"
6523 "cmp w1, w2\n"
6524 "blt 80b\n"
6525
6526 "fmla v16.4s, v4.4s, v6.s[0]\n"
6527 "fmla v18.4s, v4.4s, v6.s[1]\n"
6528 "fmla v20.4s, v4.4s, v6.s[2]\n"
6529 "fmla v22.4s, v4.4s, v6.s[3]\n"
6530 "fmla v24.4s, v4.4s, v7.s[0]\n"
6531 "fmla v26.4s, v4.4s, v7.s[1]\n"
6532 "fmla v28.4s, v4.4s, v7.s[2]\n"
6533 "fmla v30.4s, v4.4s, v7.s[3]\n"
6534 "fmla v25.4s, v5.4s, v7.s[0]\n"
6535 "fmla v27.4s, v5.4s, v7.s[1]\n"
6536 "fmla v29.4s, v5.4s, v7.s[2]\n"
6537 "fmla v31.4s, v5.4s, v7.s[3]\n"
6538 "fmla v17.4s, v5.4s, v6.s[0]\n"
6539 "fmla v19.4s, v5.4s, v6.s[1]\n"
6540 "fmla v21.4s, v5.4s, v6.s[2]\n"
6541 "fmla v23.4s, v5.4s, v6.s[3]\n"
6542
6543 "fmla v16.4s, v8.4s, v10.s[0]\n"
6544 "fmla v18.4s, v8.4s, v10.s[1]\n"
6545 "fmla v20.4s, v8.4s, v10.s[2]\n"
6546 "fmla v22.4s, v8.4s, v10.s[3]\n"
6547 "fmla v24.4s, v8.4s, v11.s[0]\n"
6548 "fmla v26.4s, v8.4s, v11.s[1]\n"
6549 "fmla v28.4s, v8.4s, v11.s[2]\n"
6550 "fmla v30.4s, v8.4s, v11.s[3]\n"
6551 "fmla v25.4s, v9.4s, v11.s[0]\n"
6552 "fmla v27.4s, v9.4s, v11.s[1]\n"
6553 "fmla v29.4s, v9.4s, v11.s[2]\n"
6554 "fmla v31.4s, v9.4s, v11.s[3]\n"
6555 "fmla v17.4s, v9.4s, v10.s[0]\n"
6556 "fmla v19.4s, v9.4s, v10.s[1]\n"
6557 "fmla v21.4s, v9.4s, v10.s[2]\n"
6558 "fmla v23.4s, v9.4s, v10.s[3]\n"
6559
6560 "fmla v16.4s, v12.4s, v14.s[0]\n"
6561 "fmla v18.4s, v12.4s, v14.s[1]\n"
6562 "fmla v20.4s, v12.4s, v14.s[2]\n"
6563 "fmla v22.4s, v12.4s, v14.s[3]\n"
6564 "fmla v24.4s, v12.4s, v15.s[0]\n"
6565 "fmla v26.4s, v12.4s, v15.s[1]\n"
6566 "fmla v28.4s, v12.4s, v15.s[2]\n"
6567 "fmla v30.4s, v12.4s, v15.s[3]\n"
6568 "fmla v25.4s, v13.4s, v15.s[0]\n"
6569 "fmla v27.4s, v13.4s, v15.s[1]\n"
6570 "fmla v29.4s, v13.4s, v15.s[2]\n"
6571 "fmla v31.4s, v13.4s, v15.s[3]\n"
6572 "fmla v17.4s, v13.4s, v14.s[0]\n"
6573 "fmla v19.4s, v13.4s, v14.s[1]\n"
6574 "fmla v21.4s, v13.4s, v14.s[2]\n"
6575 "fmla v23.4s, v13.4s, v14.s[3]\n"
6576
6577 "78:\n"
6578 #endif
6579
6580 // Accumulation loop
6581 "cmp w1, w12\n"
6582 "beq 79f\n"
6583
6584 "2:\n"
6585 "fmla v24.4s, v0.4s, v3.s[0]\n"
6586 "fmla v26.4s, v0.4s, v3.s[1]\n"
6587 "ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
6588 "fmla v28.4s, v0.4s, v3.s[2]\n"
6589 "fmla v30.4s, v0.4s, v3.s[3]\n"
6590 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6591 "fmla v25.4s, v1.4s, v3.s[0]\n"
6592 "fmla v27.4s, v1.4s, v3.s[1]\n"
6593 "add w1, w1, #1\n"
6594 "fmla v29.4s, v1.4s, v3.s[2]\n"
6595 "fmla v31.4s, v1.4s, v3.s[3]\n"
6596 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6597 "fmla v17.4s, v1.4s, v2.s[0]\n"
6598 "fmla v19.4s, v1.4s, v2.s[1]\n"
6599 "cmp w1, w12\n"
6600 "fmla v21.4s, v1.4s, v2.s[2]\n"
6601 "fmla v23.4s, v1.4s, v2.s[3]\n"
6602 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6603 "fmla v16.4s, v0.4s, v4.s[0]\n"
6604 "fmla v18.4s, v0.4s, v4.s[1]\n"
6605 "mov v2.16b, v4.16b\n"
6606 "fmla v20.4s, v0.4s, v4.s[2]\n"
6607 "fmla v22.4s, v0.4s, v4.s[3]\n"
6608 "blt 2b\n"
6609
6610 "79:\n"
6611
6612 // End of the inner loop on depth. Now perform the remaining
6613 // multiply-adds of the last level of depth, for which the LHS
6614 // and RHS data is already loaded.
6615
6616 "fmla v24.4s, v0.4s, v3.s[0]\n"
6617 "fmla v26.4s, v0.4s, v3.s[1]\n"
6618 "fmla v28.4s, v0.4s, v3.s[2]\n"
6619 "fmla v30.4s, v0.4s, v3.s[3]\n"
6620 "fmla v25.4s, v1.4s, v3.s[0]\n"
6621 "fmla v27.4s, v1.4s, v3.s[1]\n"
6622 "fmla v29.4s, v1.4s, v3.s[2]\n"
6623 "fmla v31.4s, v1.4s, v3.s[3]\n"
6624 "fmla v17.4s, v1.4s, v2.s[0]\n"
6625 "fmla v19.4s, v1.4s, v2.s[1]\n"
6626 "fmla v21.4s, v1.4s, v2.s[2]\n"
6627 "fmla v23.4s, v1.4s, v2.s[3]\n"
6628
6629        // End of accumulation. The registers v16 -- v31 contain the final
6630        // float accumulator values of the current 8x8 destination block.
6631        // We now have to apply the bias and clamp bounds to these float
6632        // accumulators, and advance to the next 8x8 block. We intertwine
6633        // these two aspects whenever possible for optimal pipelining, both
6634        // at the data flow level (prefetch data for next block as early as
6635        // possible) and instruction pipelining level (some of the next-block
6636        // work can dual-issue with some of the final work on the current
6637        // block).
6638
6639 // Logic to advance to the next block in preparation for the next
6640 // iteration of the main loop. For now, we only want to compute
6641 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
6642 // not yet ready to update the values of row and col, as we still need
6643 // the current values for the rest of the work on the current block.
6644
6645 "cmp %w[row], w7\n" // Have we finished the last row?
6646 "bge 4f\n" // If finished last row, go to 4
6647 // Not finished last row: then advance to next row.
6648 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
6649 "b 5f\n"
6650 "4:\n" // Finished last row...
6651 "mov %[lhs_col_ptr], x5\n" // Go back to first row
6652 // Now we need to advance to the next column. If we already
6653 // finished the last column, then in principle we are done, however
6654 // we can't just return here, as we need to allow the end work of the
6655 // current block to complete. The good news is that at this point it
6656 // doesn't matter what data we load for the next column, since
6657 // we will exit from the main loop below before actually storing
6658 // anything computed from that data.
6659 "cmp %w[col], w8\n" // Have we finished the last column?
6660 "bge 5f\n" // If yes, just carry on without updating the column pointer.
6661 // Not finished last column: then advance to next column.
6662 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
6663 "5:\n"
6664
6665 // Set the LHS and RHS data pointers to the start of the columns just
6666 // computed.
6667 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
6668 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
6669
6670 // Load some parameters needed for the end work on current block.
6671 "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
6672 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
6673
6674 // Offset these base pointers as needed given the current row, col.
6675 "add x5, x1, %x[row], lsl #2\n"
6676
6677 "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
6678 "csel x1, x1, x5, eq\n"
6679
6680 // Load 8 bias values.
6681 "ld1 {v14.4s}, [x1], #16\n"
6682 "ld1 {v15.4s}, [x1]\n"
6683
6684 // Now that we know what LHS and RHS data the next iteration of the
6685 // main loop will need to load, we start loading the first 32 bytes of
6686 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
6687 // in the rest of the work on the current block.
6688 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6689 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6690 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
6691 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6692
6693        // Perform the bias-addition. (Unlike in the quantized kernels above,
6694        // there is no zero-point term folded into the bias here: this is float.)
6695 "fadd v16.4s, v16.4s, v14.4s\n"
6696 "fadd v17.4s, v17.4s, v15.4s\n"
6697 "fadd v18.4s, v18.4s, v14.4s\n"
6698 "fadd v19.4s, v19.4s, v15.4s\n"
6699 "fadd v20.4s, v20.4s, v14.4s\n"
6700 "fadd v21.4s, v21.4s, v15.4s\n"
6701 "fadd v22.4s, v22.4s, v14.4s\n"
6702 "fadd v23.4s, v23.4s, v15.4s\n"
6703 "fadd v24.4s, v24.4s, v14.4s\n"
6704 "fadd v25.4s, v25.4s, v15.4s\n"
6705 "fadd v26.4s, v26.4s, v14.4s\n"
6706 "fadd v27.4s, v27.4s, v15.4s\n"
6707 "fadd v28.4s, v28.4s, v14.4s\n"
6708 "fadd v29.4s, v29.4s, v15.4s\n"
6709 "fadd v30.4s, v30.4s, v14.4s\n"
6710 "fadd v31.4s, v31.4s, v15.4s\n"
6711
6712 // Load the clamp_min, clamp_max bounds
6713 "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
6714 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
6715 "dup v14.4s, w2\n" // clamp_min
6716 "dup v15.4s, w3\n" // clamp_max
6717
6718 // Apply the clamp_min bound
6719 "fmax v16.4s, v16.4s, v14.4s\n"
6720 "fmax v17.4s, v17.4s, v14.4s\n"
6721 "fmax v18.4s, v18.4s, v14.4s\n"
6722 "fmax v19.4s, v19.4s, v14.4s\n"
6723 "fmax v20.4s, v20.4s, v14.4s\n"
6724 "fmax v21.4s, v21.4s, v14.4s\n"
6725 "fmax v22.4s, v22.4s, v14.4s\n"
6726 "fmax v23.4s, v23.4s, v14.4s\n"
6727 "fmax v24.4s, v24.4s, v14.4s\n"
6728 "fmax v25.4s, v25.4s, v14.4s\n"
6729 "fmax v26.4s, v26.4s, v14.4s\n"
6730 "fmax v27.4s, v27.4s, v14.4s\n"
6731 "fmax v28.4s, v28.4s, v14.4s\n"
6732 "fmax v29.4s, v29.4s, v14.4s\n"
6733 "fmax v30.4s, v30.4s, v14.4s\n"
6734 "fmax v31.4s, v31.4s, v14.4s\n"
6735
6736 // Apply the clamp_max bound
6737 "fmin v16.4s, v16.4s, v15.4s\n"
6738 "fmin v17.4s, v17.4s, v15.4s\n"
6739 "fmin v18.4s, v18.4s, v15.4s\n"
6740 "fmin v19.4s, v19.4s, v15.4s\n"
6741 "fmin v20.4s, v20.4s, v15.4s\n"
6742 "fmin v21.4s, v21.4s, v15.4s\n"
6743 "fmin v22.4s, v22.4s, v15.4s\n"
6744 "fmin v23.4s, v23.4s, v15.4s\n"
6745 "fmin v24.4s, v24.4s, v15.4s\n"
6746 "fmin v25.4s, v25.4s, v15.4s\n"
6747 "fmin v26.4s, v26.4s, v15.4s\n"
6748 "fmin v27.4s, v27.4s, v15.4s\n"
6749 "fmin v28.4s, v28.4s, v15.4s\n"
6750 "fmin v29.4s, v29.4s, v15.4s\n"
6751 "fmin v30.4s, v30.4s, v15.4s\n"
6752 "fmin v31.4s, v31.4s, v15.4s\n"
6753
6754        // Compute how much of the 8x8 block of destination float values that
6755        // we have computed fits in the destination matrix. Typically, all of
6756        // it fits, but when the destination matrix shape is not a multiple
6757        // of 8x8, there are some 8x8 blocks along the boundaries that do
6758        // not fit entirely.
6759 "sub w1, %w[dst_rows], %w[row]\n"
6760 "sub w2, %w[dst_cols], %w[col]\n"
6761 "mov w3, #8\n"
6762 "cmp w1, #8\n"
6763 // Compute w1 = how many rows of the 8x8 block fit
6764 "csel w1, w1, w3, le\n"
6765 "cmp w2, #8\n"
6766 // Compute w2 = how many cols of the 8x8 block fit
6767 "csel w2, w2, w3, le\n"
6768
6769 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
6770 "cmp w1, w3\n"
6771 "ccmp w2, w3, 0, eq\n"
6772 // Yes, all of the 8x8 block fits, go to fast path.
6773 "beq 30f\n"
6774 // Not all of the 8x8 block fits.
6775 // Set (x3 address, x4 stride) to write to dst_tmp_buf
6776 "mov x3, %[dst_tmp_buf]\n"
6777 "mov x4, #32\n"
6778 "b 31f\n"
6779 "30:\n"
6780 // Yes, all of the 8x8 block fits.
6781 // Set (x3 address, x4 stride) to write directly to destination matrix.
6782 "mov x3, %[dst_ptr]\n"
6783 "mov x4, x11\n"
6784 "31:\n"
6785
6786        // Write our float values to the destination described by
6787 // (x3 address, x4 stride).
6788 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6789 "str q16, [x3, #0]\n"
6790 "str q17, [x3, #16]\n"
6791 "add x3, x3, x4\n"
6792 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6793 RUY_MAKE_ZERO(v16)
6794 RUY_MAKE_ZERO(v17)
6795 "str q18, [x3, #0]\n"
6796 "str q19, [x3, #16]\n"
6797 "add x3, x3, x4\n"
6798 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6799 RUY_MAKE_ZERO(v18)
6800 RUY_MAKE_ZERO(v19)
6801 "str q20, [x3, #0]\n"
6802 "str q21, [x3, #16]\n"
6803 "add x3, x3, x4\n"
6804 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6805 RUY_MAKE_ZERO(v20)
6806 RUY_MAKE_ZERO(v21)
6807 "str q22, [x3, #0]\n"
6808 "str q23, [x3, #16]\n"
6809 "add x3, x3, x4\n"
6810 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6811 RUY_MAKE_ZERO(v22)
6812 RUY_MAKE_ZERO(v23)
6813 "str q24, [x3, #0]\n"
6814 "str q25, [x3, #16]\n"
6815 "add x3, x3, x4\n"
6816 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6817 RUY_MAKE_ZERO(v24)
6818 RUY_MAKE_ZERO(v25)
6819 "str q26, [x3, #0]\n"
6820 "str q27, [x3, #16]\n"
6821 "add x3, x3, x4\n"
6822 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6823 RUY_MAKE_ZERO(v26)
6824 RUY_MAKE_ZERO(v27)
6825 "str q28, [x3, #0]\n"
6826 "str q29, [x3, #16]\n"
6827 "add x3, x3, x4\n"
6828 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6829 RUY_MAKE_ZERO(v28)
6830 RUY_MAKE_ZERO(v29)
6831 "str q30, [x3, #0]\n"
6832 "str q31, [x3, #16]\n"
6833 RUY_MAKE_ZERO(v30)
6834 RUY_MAKE_ZERO(v31)
6835
6836 // If all of the 8x8 block fits, we just finished writing it to the
6837 // destination, so we skip the next part.
6838 "beq 41f\n"
6839 // Not all of the 8x8 block fits in the destination matrix. We just
6840 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
6841 // it to copy into the destination matrix the part that fits.
6842 "mov x3, %[dst_tmp_buf]\n"
6843 "mov x4, %[dst_ptr]\n"
6844 "mov w6, #0\n"
6845 "50:\n"
6846 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6847 "mov w5, #0\n"
6848 "51:\n"
6849 "ldr w7, [x3, x5, lsl #2]\n"
6850 "str w7, [x4, x5, lsl #2]\n"
6851 "add w5, w5, #1\n"
6852 "cmp w5, w1\n"
6853 "blt 51b\n"
6854 "add w6, w6, #1\n"
6855 "add x3, x3, #32\n"
6856 "add x4, x4, x11\n"
6857 "cmp w6, w2\n"
6858 "blt 50b\n"
6859 "41:\n"
6860 "add %[dst_ptr], %[dst_ptr], #32\n"
6861 // At this point we have completely finished writing values to the
6862 // destination matrix for the current block.
6863
6864 // Reload some params --- we had used x5 -- x7 for a few other things
6865 // since the last time we had loaded them.
6866 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
6867 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
6868 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
6869
6870 // Move to the next block of the destination matrix, for the next iter
6871 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
6872 // been updated earlier.
6873 // Have we reached the end row?
6874 "cmp %w[row], w7\n"
6875 "beq 20f\n" // yes, end row.
6876 // Not end row. Move to the next row.
6877 "add %w[row], %w[row], #8\n"
6878 "b 21f\n"
6879 "20:\n"
6880 // Was already at end row.
6881 "mov %w[row], w6\n" // Move back to first row.
6882 "add %w[col], %w[col], #8\n" // Move to the next column.
6883 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
6884 "mov %[dst_ptr], %[dst_col_ptr]\n"
6885 "21:\n"
6886
6887 // Main loop exit condition: have we hit the end column?
6888 "cmp %w[col], w8\n"
6889
6890 // w1 is the number of levels of depth that we have already loaded
6891 // LHS and RHS data for. Corresponding to the initial ld1 instructions
6892 // above, this is currently 1.
6893 "mov w1, #1\n"
6894
6895 "ble 1b\n"
6896
6897 // clang-format on
6898
6899 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
6900 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
6901 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
6902         : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
6903 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
6904 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
6905 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
6906 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
6907 "v26", "v27", "v28", "v29", "v30", "v31");
6908 }
6909
6910 // Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not
6911 // support dotprod (while dotprod by itself is not relevant to floating-point,
6912 // this additional bit of information that we have about the target happens to
6913 // be useful here).
6914 //
6915 // So a typical target CPU here would be ARM Cortex-A53 or the original
6916 // Cortex-A55.
6917 //
6918 // This kernel is similar to and inspired by gemmlowp's
6919 // NEON_64bit_GEMM_Float32_WithScalar_A53,
6920 // which was contributed by David Mansell with very helpful
6921 // comments. Specifically, see this comment about tuning for Cortex-A53:
6922 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
6923 void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) {
6924 profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
6925
6926 CheckOffsetsInKernelParamsFloat(params);
6927
6928 const float* lhs_col_ptr = params.lhs_base_ptr;
6929 const float* rhs_col_ptr = params.rhs_base_ptr;
6930 const float* lhs_ptr = lhs_col_ptr;
6931 const float* rhs_ptr = rhs_col_ptr;
6932 float* dst_col_ptr = params.dst_base_ptr;
6933 float* dst_ptr = dst_col_ptr;
6934 int row = params.start_row;
6935 int col = params.start_col;
6936
6937 // The asm kernel below has the following NEON register allocation:
6938 //
6939 // v16 -- v31 are accumulators.
6940 // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
6941 //
6942 // RHS 1x8 block
6943 // /-----------------------------------------\
6944 // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
6945 // \-----------------------------------------/
6946 // LHS 8x1 block
6947 // /---------------------\ /-----------------------------------------\
6948 // | v0.s[0] | |v16.s[0] ... v30.s[0]|
6949 // | ... | | ... ... |
6950 // | v0.s[3] | |v16.s[3] ... v30.s[3]|
6951 // | v1.s[0] | |v17.s[0] ... v31.s[0]|
6952 // | ... | | ... ... |
6953 // | v1.s[3] | |v17.s[3] ... v31.s[3]|
6954 // \---------------------/ \-----------------------------------------/
6955 // accumulators 8x8 block
6956 //
6957 // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
6958 // we did not observe a benefit of such partial unrolling on in-order CPUs.
6959 //
6960   // v4 -- v7 are unused, and v8 -- v15 are used to load parameters for
6961   // the post-accumulation part of the kernel.
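  // A load-scheduling note (per the gemmlowp A53 comment linked above):
  // in the accumulation loop below, 128-bit vectors are loaded as a 64-bit
  // 'ldr d' plus a 64-bit integer 'ldr x' followed by an 'ins' of the high
  // half, which these in-order cores handle with fewer stalls than full
  // 128-bit vector loads.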
6962 asm volatile(
6963 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
6964
6965 // clang-format off
6966
6967 // Load some parameters into registers.
6968 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
6969 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
6970 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
6971 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
6972 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
6973 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
6974 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
6975 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
6976
6977
6978 // Clear accumulators.
6979 RUY_MAKE_ZERO(v16)
6980 // Load the first 32 bytes of LHS and RHS data.
6981 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6982 RUY_MAKE_ZERO(v17)
6983 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6984 RUY_MAKE_ZERO(v18)
6985 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
6986 RUY_MAKE_ZERO(v19)
6987 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6988 RUY_MAKE_ZERO(v20)
6989 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
6990 RUY_MAKE_ZERO(v21)
6991 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
6992 RUY_MAKE_ZERO(v22)
6993 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
6994 RUY_MAKE_ZERO(v23)
6995 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
6996 RUY_MAKE_ZERO(v24)
6997 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
6998 RUY_MAKE_ZERO(v25)
6999 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
7000 RUY_MAKE_ZERO(v26)
7001 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
7002 RUY_MAKE_ZERO(v27)
7003 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
7004 RUY_MAKE_ZERO(v28)
7005 RUY_MAKE_ZERO(v29)
7006 RUY_MAKE_ZERO(v30)
7007 RUY_MAKE_ZERO(v31)
7008
7009 // w1 is the number of levels of depth that remain to load
7010 // LHS and RHS data for. Corresponding to the initial ld1 instructions
7011 // above, this is currently depth - 1.
7012 "sub w1, w12, #1\n"
7013
7014 // Main loop of the whole GEMM, over rows and columns of the
7015 // destination matrix.
7016 "1:\n"
7017
7018 "cmp w1, #0\n"
7019 "fmla v16.4s, v0.4s, v2.s[0]\n"
7020 "fmla v18.4s, v0.4s, v2.s[1]\n"
7021 "fmla v20.4s, v0.4s, v2.s[2]\n"
7022 "fmla v22.4s, v0.4s, v2.s[3]\n"
7023
7024 // Accumulation loop
7025 "beq 79f\n"
7026
7027 "2:\n"
7028
7029 "fmla v24.4s, v0.4s, v3.s[0]\n"
7030 "ldr x2, [%[lhs_ptr], #8]\n"
7031 "fmla v26.4s, v0.4s, v3.s[1]\n"
7032 "ldr x3, [%[lhs_ptr], #24]\n"
7033 "fmla v28.4s, v0.4s, v3.s[2]\n"
7034 "ldr x5, [%[rhs_ptr], #24]\n"
7035 "fmla v30.4s, v0.4s, v3.s[3]\n"
7036 "ldr x4, [%[rhs_ptr], #8]\n"
7037 "fmla v25.4s, v1.4s, v3.s[0]\n"
7038 "subs w1, w1, #1\n"
7039 "ldr d0, [%[lhs_ptr]], #32\n"
7040 "fmla v27.4s, v1.4s, v3.s[1]\n"
7041 "fmla v29.4s, v1.4s, v3.s[2]\n"
7042 "fmla v31.4s, v1.4s, v3.s[3]\n"
7043 "ins v0.d[1], x2\n"
7044 "ldr d3, [%[rhs_ptr], #16]\n"
7045 "fmla v17.4s, v1.4s, v2.s[0]\n"
7046 "fmla v19.4s, v1.4s, v2.s[1]\n"
7047 "ins v3.d[1], x5\n"
7048 "ldr d4, [%[rhs_ptr]], #32\n"
7049 "fmla v21.4s, v1.4s, v2.s[2]\n"
7050 "fmla v23.4s, v1.4s, v2.s[3]\n"
7051 "fmla v16.4s, v0.4s, v4.s[0]\n"
7052 "ins v4.d[1], x4\n"
7053 "ldr d1, [%[lhs_ptr], #-16]\n"
7054 "fmla v18.4s, v0.4s, v4.s[1]\n"
7055 "fmla v20.4s, v0.4s, v4.s[2]\n"
7056 "ins v1.d[1], x3\n"
7057 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
7058 "mov v2.16b, v4.16b\n"
7059 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
7060 "fmla v22.4s, v0.4s, v4.s[3]\n"
7061 "bne 2b\n"
7062
7063 "79:\n"
7064
7065 // End of the inner loop on depth. Now perform the remaining
7066 // multiply-adds of the last level of depth, for which the LHS
7067 // and RHS data is already loaded.
7068
7069 "fmla v24.4s, v0.4s, v3.s[0]\n"
7070 "fmla v26.4s, v0.4s, v3.s[1]\n"
7071 "fmla v28.4s, v0.4s, v3.s[2]\n"
7072 "fmla v30.4s, v0.4s, v3.s[3]\n"
7073 "fmla v25.4s, v1.4s, v3.s[0]\n"
7074 "fmla v27.4s, v1.4s, v3.s[1]\n"
7075 "fmla v29.4s, v1.4s, v3.s[2]\n"
7076 "fmla v31.4s, v1.4s, v3.s[3]\n"
7077 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7078 "fmla v17.4s, v1.4s, v2.s[0]\n"
7079 "fmla v19.4s, v1.4s, v2.s[1]\n"
7080 "fmla v21.4s, v1.4s, v2.s[2]\n"
7081 "fmla v23.4s, v1.4s, v2.s[3]\n"
7082
7083        // End of accumulation. The registers v16 -- v31 contain the final
7084        // float accumulator values of the current 8x8 destination block.
7085        // We now have to apply the bias and clamp bounds to these float
7086        // accumulators, and advance to the next 8x8 block. We intertwine
7087        // these two aspects whenever possible for optimal pipelining, both
7088        // at the data flow level (prefetch data for next block as early as
7089        // possible) and instruction pipelining level (some of the next-block
7090        // work can dual-issue with some of the final work on the current
7091        // block).
7092
7093 // Logic to advance to the next block in preparation for the next
7094 // iteration of the main loop. For now, we only want to compute
7095 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
7096 // not yet ready to update the values of row and col, as we still need
7097 // the current values for the rest of the work on the current block.
7098
7099 "cmp %w[row], w7\n" // Have we finished the last row?
7100 "bge 4f\n" // If finished last row, go to 4
7101 // Not finished last row: then advance to next row.
7102 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
7103 "b 5f\n"
7104 "4:\n" // Finished last row...
7105 "mov %[lhs_col_ptr], x5\n" // Go back to first row
7106 // Now we need to advance to the next column. If we already
7107 // finished the last column, then in principle we are done, however
7108 // we can't just return here, as we need to allow the end work of the
7109 // current block to complete. The good news is that at this point it
7110 // doesn't matter what data we load for the next column, since
7111 // we will exit from the main loop below before actually storing
7112 // anything computed from that data.
7113 "cmp %w[col], w8\n" // Have we finished the last column?
7114 "bge 5f\n" // If yes, just carry on without updating the column pointer.
7115 // Not finished last column: then advance to next column.
7116 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
7117 "5:\n"
7118
7119 // Set the LHS and RHS data pointers to the start of the columns just
7120 // computed.
7121 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
7122 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
7123
7124 // Load some parameters needed for the end work on current block.
7125 "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
7126 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
7127
7128 // Offset these base pointers as needed given the current row, col.
7129 "add x5, x1, %x[row], lsl #2\n"
7130
7131 "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
7132 "csel x1, x1, x5, eq\n"
7133
7134 // Load 8 bias values.
7135 "ld1 {v14.4s}, [x1], #16\n"
7136 "ld1 {v15.4s}, [x1]\n"
7137
7138 // Now that we know what LHS and RHS data the next iteration of the
7139 // main loop will need to load, we start loading the first 32 bytes of
7140 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
7141 // in the rest of the work on the current block.
7142 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
7143 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
7144 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
7145 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
7146
7147        // Perform the bias-addition. (As in the other float kernels, there is
7148        // no zero-point term folded into the bias here.)
7149 "fadd v16.4s, v16.4s, v14.4s\n"
7150 "fadd v17.4s, v17.4s, v15.4s\n"
7151 "fadd v18.4s, v18.4s, v14.4s\n"
7152 "fadd v19.4s, v19.4s, v15.4s\n"
7153 "fadd v20.4s, v20.4s, v14.4s\n"
7154 "fadd v21.4s, v21.4s, v15.4s\n"
7155 "fadd v22.4s, v22.4s, v14.4s\n"
7156 "fadd v23.4s, v23.4s, v15.4s\n"
7157 "fadd v24.4s, v24.4s, v14.4s\n"
7158 "fadd v25.4s, v25.4s, v15.4s\n"
7159 "fadd v26.4s, v26.4s, v14.4s\n"
7160 "fadd v27.4s, v27.4s, v15.4s\n"
7161 "fadd v28.4s, v28.4s, v14.4s\n"
7162 "fadd v29.4s, v29.4s, v15.4s\n"
7163 "fadd v30.4s, v30.4s, v14.4s\n"
7164 "fadd v31.4s, v31.4s, v15.4s\n"
7165
7166 // Load the clamp_min, clamp_max bounds
7167 "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
7168 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
7169 "dup v14.4s, w2\n" // clamp_min
7170 "dup v15.4s, w3\n" // clamp_max
7171
7172 // Apply the clamp_min bound
7173 "fmax v16.4s, v16.4s, v14.4s\n"
7174 "fmax v17.4s, v17.4s, v14.4s\n"
7175 "fmax v18.4s, v18.4s, v14.4s\n"
7176 "fmax v19.4s, v19.4s, v14.4s\n"
7177 "fmax v20.4s, v20.4s, v14.4s\n"
7178 "fmax v21.4s, v21.4s, v14.4s\n"
7179 "fmax v22.4s, v22.4s, v14.4s\n"
7180 "fmax v23.4s, v23.4s, v14.4s\n"
7181 "fmax v24.4s, v24.4s, v14.4s\n"
7182 "fmax v25.4s, v25.4s, v14.4s\n"
7183 "fmax v26.4s, v26.4s, v14.4s\n"
7184 "fmax v27.4s, v27.4s, v14.4s\n"
7185 "fmax v28.4s, v28.4s, v14.4s\n"
7186 "fmax v29.4s, v29.4s, v14.4s\n"
7187 "fmax v30.4s, v30.4s, v14.4s\n"
7188 "fmax v31.4s, v31.4s, v14.4s\n"
7189
7190 // Apply the clamp_max bound
7191 "fmin v16.4s, v16.4s, v15.4s\n"
7192 "fmin v17.4s, v17.4s, v15.4s\n"
7193 "fmin v18.4s, v18.4s, v15.4s\n"
7194 "fmin v19.4s, v19.4s, v15.4s\n"
7195 "fmin v20.4s, v20.4s, v15.4s\n"
7196 "fmin v21.4s, v21.4s, v15.4s\n"
7197 "fmin v22.4s, v22.4s, v15.4s\n"
7198 "fmin v23.4s, v23.4s, v15.4s\n"
7199 "fmin v24.4s, v24.4s, v15.4s\n"
7200 "fmin v25.4s, v25.4s, v15.4s\n"
7201 "fmin v26.4s, v26.4s, v15.4s\n"
7202 "fmin v27.4s, v27.4s, v15.4s\n"
7203 "fmin v28.4s, v28.4s, v15.4s\n"
7204 "fmin v29.4s, v29.4s, v15.4s\n"
7205 "fmin v30.4s, v30.4s, v15.4s\n"
7206 "fmin v31.4s, v31.4s, v15.4s\n"
7207
7208        // Compute how much of the 8x8 block of destination float values that
7209        // we have computed fits in the destination matrix. Typically, all of
7210        // it fits, but when the destination matrix shape is not a multiple
7211        // of 8x8, there are some 8x8 blocks along the boundaries that do
7212        // not fit entirely.
7213 "sub w1, %w[dst_rows], %w[row]\n"
7214 "sub w2, %w[dst_cols], %w[col]\n"
7215 "mov w3, #8\n"
7216 "cmp w1, #8\n"
7217 // Compute w1 = how many rows of the 8x8 block fit
7218 "csel w1, w1, w3, le\n"
7219 "cmp w2, #8\n"
7220 // Compute w2 = how many cols of the 8x8 block fit
7221 "csel w2, w2, w3, le\n"
7222
7223 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
7224 "cmp w1, w3\n"
7225 "ccmp w2, w3, 0, eq\n"
7226 // Yes, all of the 8x8 block fits, go to fast path.
7227 "beq 30f\n"
7228 // Not all of the 8x8 block fits.
7229 // Set (x3 address, x4 stride) to write to dst_tmp_buf
7230 "mov x3, %[dst_tmp_buf]\n"
7231 "mov x4, #32\n"
7232 "b 31f\n"
7233 "30:\n"
7234 // Yes, all of the 8x8 block fits.
7235 // Set (x3 address, x4 stride) to write directly to destination matrix.
7236 "mov x3, %[dst_ptr]\n"
7237 "mov x4, x11\n"
7238 "31:\n"
7239
7240        // Write our float values to the destination described by
7241 // (x3 address, x4 stride).
7242 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7243 "str q16, [x3, #0]\n"
7244 "str q17, [x3, #16]\n"
7245 "add x3, x3, x4\n"
7246 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7247 RUY_MAKE_ZERO(v16)
7248 RUY_MAKE_ZERO(v17)
7249 "str q18, [x3, #0]\n"
7250 "str q19, [x3, #16]\n"
7251 "add x3, x3, x4\n"
7252 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7253 RUY_MAKE_ZERO(v18)
7254 RUY_MAKE_ZERO(v19)
7255 "str q20, [x3, #0]\n"
7256 "str q21, [x3, #16]\n"
7257 "add x3, x3, x4\n"
7258 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7259 RUY_MAKE_ZERO(v20)
7260 RUY_MAKE_ZERO(v21)
7261 "str q22, [x3, #0]\n"
7262 "str q23, [x3, #16]\n"
7263 "add x3, x3, x4\n"
7264 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7265 RUY_MAKE_ZERO(v22)
7266 RUY_MAKE_ZERO(v23)
7267 "str q24, [x3, #0]\n"
7268 "str q25, [x3, #16]\n"
7269 "add x3, x3, x4\n"
7270 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7271 RUY_MAKE_ZERO(v24)
7272 RUY_MAKE_ZERO(v25)
7273 "str q26, [x3, #0]\n"
7274 "str q27, [x3, #16]\n"
7275 "add x3, x3, x4\n"
7276 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7277 RUY_MAKE_ZERO(v26)
7278 RUY_MAKE_ZERO(v27)
7279 "str q28, [x3, #0]\n"
7280 "str q29, [x3, #16]\n"
7281 "add x3, x3, x4\n"
7282 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7283 RUY_MAKE_ZERO(v28)
7284 RUY_MAKE_ZERO(v29)
7285 "str q30, [x3, #0]\n"
7286 "str q31, [x3, #16]\n"
7287 RUY_MAKE_ZERO(v30)
7288 RUY_MAKE_ZERO(v31)
7289
7290 // If all of the 8x8 block fits, we just finished writing it to the
7291 // destination, so we skip the next part.
7292 "beq 41f\n"
7293 // Not all of the 8x8 block fits in the destination matrix. We just
7294 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
7295 // it to copy into the destination matrix the part that fits.
7296 "mov x3, %[dst_tmp_buf]\n"
7297 "mov x4, %[dst_ptr]\n"
7298 "mov w6, #0\n"
7299 "50:\n"
7300 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
7301 "mov w5, #0\n"
7302 "51:\n"
7303 "ldr w7, [x3, x5, lsl #2]\n"
7304 "str w7, [x4, x5, lsl #2]\n"
7305 "add w5, w5, #1\n"
7306 "cmp w5, w1\n"
7307 "blt 51b\n"
7308 "add w6, w6, #1\n"
7309 "add x3, x3, #32\n"
7310 "add x4, x4, x11\n"
7311 "cmp w6, w2\n"
7312 "blt 50b\n"
7313 "41:\n"
7314 "add %[dst_ptr], %[dst_ptr], #32\n"
7315 // At this point we have completely finished writing values to the
7316 // destination matrix for the current block.
7317
7318 // Reload some params --- we had used x5 -- x7 for a few other things
7319 // since the last time we had loaded them.
7320 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7321 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
7322 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
7323
7324 // Move to the next block of the destination matrix, for the next iter
7325 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
7326 // been updated earlier.
7327 // Have we reached the end row?
7328 "cmp %w[row], w7\n"
7329 "beq 20f\n" // yes, end row.
7330 // Not end row. Move to the next row.
7331 "add %w[row], %w[row], #8\n"
7332 "b 21f\n"
7333 "20:\n"
7334 // Was already at end row.
7335 "mov %w[row], w6\n" // Move back to first row.
7336 "add %w[col], %w[col], #8\n" // Move to the next column.
7337 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
7338 "mov %[dst_ptr], %[dst_col_ptr]\n"
7339 "21:\n"
7340
7341 // Main loop exit condition: have we hit the end column?
7342 "cmp %w[col], w8\n"
7343
7344 // w1 is the number of levels of depth that remain to load
7345 // LHS and RHS data for. Corresponding to the initial ld1 instructions
7346 // above, this is currently depth - 1.
7347 "sub w1, w12, #1\n"
7348
7349 "ble 1b\n"
7350
7351 // clang-format on
7352
7353 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
7354 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
7355 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
7356         : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
7357 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
7358 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
7359 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
7360 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
7361 "v26", "v27", "v28", "v29", "v30", "v31");
7362 }
7363
7364 // Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do
7365 // support dotprod (while dotprod by itself is not relevant to floating-point,
7366 // this additional bit of information that we have about the target happens to
7367 // be useful here).
7368 //
7369 // So a typical target CPU here would be ARM Cortex-A55r1.
7370 //
7371 // This kernel is similar to and inspired by gemmlowp's
7372 // NEON_64bit_GEMM_Float32_WithScalar_A55r1,
7373 // which was contributed by David Mansell with very helpful
7374 // comments. Specifically, see this comment about tuning for Cortex-A55r1:
7375 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
7376 void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) {
7377 profiler::ScopeLabel label(
7378 "Kernel (kNeonDotprod, optimized for in-order cores)");
7379
7380 CheckOffsetsInKernelParamsFloat(params);
7381
7382 const float* lhs_col_ptr = params.lhs_base_ptr;
7383 const float* rhs_col_ptr = params.rhs_base_ptr;
7384 const float* lhs_ptr = lhs_col_ptr;
7385 const float* rhs_ptr = rhs_col_ptr;
7386 float* dst_col_ptr = params.dst_base_ptr;
7387 float* dst_ptr = dst_col_ptr;
7388 int row = params.start_row;
7389 int col = params.start_col;
7390
7391 // The asm kernel below has the following NEON register allocation:
7392 //
7393 // v16 -- v31 are accumulators.
7394 // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
7395 //
7396 // RHS 1x8 block
7397 // /-----------------------------------------\
7398 // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
7399 // \-----------------------------------------/
7400 // LHS 8x1 block
7401 // /---------------------\ /-----------------------------------------\
7402 // | v0.s[0] | |v16.s[0] ... v30.s[0]|
7403 // | ... | | ... ... |
7404 // | v0.s[3] | |v16.s[3] ... v30.s[3]|
7405 // | v1.s[0] | |v17.s[0] ... v31.s[0]|
7406 // | ... | | ... ... |
7407 // | v1.s[3] | |v17.s[3] ... v31.s[3]|
7408 // \---------------------/ \-----------------------------------------/
7409 // accumulators 8x8 block
7410 //
7411 // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
7412 // we did not observe a benefit of such partial unrolling on in-order CPUs.
7413 //
7414   // v4 -- v7 are unused, and v8 -- v15 are used to load parameters for
7415   // the post-accumulation part of the kernel.
7416 asm volatile(
7417 #define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
7418
7419 // clang-format off
7420
7421 // Load some parameters into registers.
7422 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7423 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
7424 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
7425 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
7426 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
7427 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
7428 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
7429 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"


        // Clear accumulators.
        RUY_MAKE_ZERO(v16)
        // Load the first 32 bytes of LHS and RHS data.
        "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
        RUY_MAKE_ZERO(v17)
        "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
        RUY_MAKE_ZERO(v18)
        "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
        RUY_MAKE_ZERO(v19)
        "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
        RUY_MAKE_ZERO(v20)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
        RUY_MAKE_ZERO(v21)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
        RUY_MAKE_ZERO(v22)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
        RUY_MAKE_ZERO(v23)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
        RUY_MAKE_ZERO(v24)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
        RUY_MAKE_ZERO(v25)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
        RUY_MAKE_ZERO(v26)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
        RUY_MAKE_ZERO(v27)
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // w1 is the number of levels of depth that remain to load
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently depth - 1.
        "sub w1, w12, #1\n"

        // Main loop of the whole GEMM, over rows and columns of the
        // destination matrix.
        "1:\n"

        "cmp w1, #0\n"
        "fmla v16.4s, v0.4s, v2.s[0]\n"
        "fmla v18.4s, v0.4s, v2.s[1]\n"
        "fmla v20.4s, v0.4s, v2.s[2]\n"
        "fmla v22.4s, v0.4s, v2.s[3]\n"

        // Accumulation loop
        "beq 79f\n"

        "2:\n"
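        // In-order-friendly load pattern: the 128-bit loads of the next
        // LHS/RHS data are split into 64-bit halves -- "ldr d" fills the
        // low half of a vector register directly, while "ldr x" + "ins"
        // routes the high half through a general-purpose register. Unlike
        // a single 128-bit "ldr q", these 64-bit pieces can dual-issue
        // with the surrounding fmla instructions on the in-order cores
        // this kernel targets.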

        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
        "fmla v24.4s, v0.4s, v3.s[0]\n"
        "ldr x2, [%[lhs_ptr], #8]\n"
        "fmla v26.4s, v0.4s, v3.s[1]\n"
        "ldr x3, [%[lhs_ptr], #24]\n"
        "fmla v28.4s, v0.4s, v3.s[2]\n"
        "ldr x5, [%[rhs_ptr], #24]\n"
        "fmla v30.4s, v0.4s, v3.s[3]\n"
        "ldr d0, [%[lhs_ptr]], #32\n"
        "fmla v25.4s, v1.4s, v3.s[0]\n"
        "ldr x4, [%[rhs_ptr], #8]\n"
        "fmla v27.4s, v1.4s, v3.s[1]\n"
        "subs w1, w1, #1\n"
        "fmla v29.4s, v1.4s, v3.s[2]\n"
        "ins v0.d[1], x2\n"
        "fmla v31.4s, v1.4s, v3.s[3]\n"
        "ldr d3, [%[rhs_ptr], #16]\n"
        "fmla v17.4s, v1.4s, v2.s[0]\n"
        "ins v3.d[1], x5\n"
        "fmla v19.4s, v1.4s, v2.s[1]\n"
        "ldr d4, [%[rhs_ptr]], #32\n"
        "fmla v21.4s, v1.4s, v2.s[2]\n"
        "ins v4.d[1], x4\n"
        "fmla v23.4s, v1.4s, v2.s[3]\n"
        RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
        "fmla v16.4s, v0.4s, v4.s[0]\n"
        "ldr d1, [%[lhs_ptr], #-16]\n"
        "fmla v18.4s, v0.4s, v4.s[1]\n"
        "ins v1.d[1], x3\n"
        "fmla v20.4s, v0.4s, v4.s[2]\n"
        "mov v2.16b, v4.16b\n"
        "fmla v22.4s, v0.4s, v4.s[3]\n"
        "bne 2b\n"

        "79:\n"

        // End of the inner loop on depth. Now perform the remaining
        // multiply-adds of the last level of depth, for which the LHS
        // and RHS data is already loaded.

        "fmla v24.4s, v0.4s, v3.s[0]\n"
        "fmla v26.4s, v0.4s, v3.s[1]\n"
        "fmla v28.4s, v0.4s, v3.s[2]\n"
        "fmla v30.4s, v0.4s, v3.s[3]\n"
        "fmla v25.4s, v1.4s, v3.s[0]\n"
        "fmla v27.4s, v1.4s, v3.s[1]\n"
        "fmla v29.4s, v1.4s, v3.s[2]\n"
        "fmla v31.4s, v1.4s, v3.s[3]\n"
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "fmla v17.4s, v1.4s, v2.s[0]\n"
        "fmla v19.4s, v1.4s, v2.s[1]\n"
        "fmla v21.4s, v1.4s, v2.s[2]\n"
        "fmla v23.4s, v1.4s, v2.s[3]\n"

        // End of accumulation. The registers v16 -- v31 contain the final
        // float32 accumulator values of the current 8x8 destination block.
        // We now have to store these values back to the destination matrix,
        // and advance to the next 8x8 block. We intertwine these two aspects
        // whenever possible for optimal pipelining, both at the data flow
        // level (prefetch data for next block as early as possible) and
        // instruction pipelining level (some of the next-block work can
        // dual-issue with some of the final work on the current block).

        // Logic to advance to the next block in preparation for the next
        // iteration of the main loop. For now, we only want to compute
        // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
        // not yet ready to update the values of row and col, as we still need
        // the current values for the rest of the work on the current block.

        "cmp %w[row], w7\n"  // Have we finished the last row?
        "bge 4f\n"           // If finished last row, go to 4
        // Not finished last row: then advance to next row.
        "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
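        // (x9 is lhs_stride; "lsl #3" scales it by 8 to step over the
        // 8 rows handled by this kernel.)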
        "b 5f\n"
        "4:\n"  // Finished last row...
        "mov %[lhs_col_ptr], x5\n"  // Go back to first row
        // Now we need to advance to the next column. If we already
        // finished the last column, then in principle we are done, however
        // we can't just return here, as we need to allow the end work of the
        // current block to complete. The good news is that at this point it
        // doesn't matter what data we load for the next column, since
        // we will exit from the main loop below before actually storing
        // anything computed from that data.
        "cmp %w[col], w8\n"  // Have we finished the last column?
        "bge 5f\n"  // If yes, just carry on without updating the column pointer.
        // Not finished last column: then advance to next column.
        "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
        "5:\n"

        // Set the LHS and RHS data pointers to the start of the columns just
        // computed.
        "mov %[lhs_ptr], %[lhs_col_ptr]\n"
        "mov %[rhs_ptr], %[rhs_col_ptr]\n"

        // Load some parameters needed for the end work on current block.
        "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"

        // Offset these base pointers as needed given the current row, col.
        "add x5, x1, %x[row], lsl #2\n"

        "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
        "csel x1, x1, x5, eq\n"

        // Load 8 bias values.
        "ld1 {v14.4s}, [x1], #16\n"
        "ld1 {v15.4s}, [x1]\n"

        // Now that we know what LHS and RHS data the next iteration of the
        // main loop will need to load, we start loading the first 32 bytes of
        // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
        // in the rest of the work on the current block.
        "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
        "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"

        // Perform the bias-addition. The bias is applied per destination
        // row: v14 holds the bias values for rows 0..3, v15 for rows 4..7.
        "fadd v16.4s, v16.4s, v14.4s\n"
        "fadd v17.4s, v17.4s, v15.4s\n"
        "fadd v18.4s, v18.4s, v14.4s\n"
        "fadd v19.4s, v19.4s, v15.4s\n"
        "fadd v20.4s, v20.4s, v14.4s\n"
        "fadd v21.4s, v21.4s, v15.4s\n"
        "fadd v22.4s, v22.4s, v14.4s\n"
        "fadd v23.4s, v23.4s, v15.4s\n"
        "fadd v24.4s, v24.4s, v14.4s\n"
        "fadd v25.4s, v25.4s, v15.4s\n"
        "fadd v26.4s, v26.4s, v14.4s\n"
        "fadd v27.4s, v27.4s, v15.4s\n"
        "fadd v28.4s, v28.4s, v14.4s\n"
        "fadd v29.4s, v29.4s, v15.4s\n"
        "fadd v30.4s, v30.4s, v14.4s\n"
        "fadd v31.4s, v31.4s, v15.4s\n"

        // Load the clamp_min, clamp_max bounds
        "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
        "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
        "dup v14.4s, w2\n"  // clamp_min
        "dup v15.4s, w3\n"  // clamp_max

        // Apply the clamp_min bound
        "fmax v16.4s, v16.4s, v14.4s\n"
        "fmax v17.4s, v17.4s, v14.4s\n"
        "fmax v18.4s, v18.4s, v14.4s\n"
        "fmax v19.4s, v19.4s, v14.4s\n"
        "fmax v20.4s, v20.4s, v14.4s\n"
        "fmax v21.4s, v21.4s, v14.4s\n"
        "fmax v22.4s, v22.4s, v14.4s\n"
        "fmax v23.4s, v23.4s, v14.4s\n"
        "fmax v24.4s, v24.4s, v14.4s\n"
        "fmax v25.4s, v25.4s, v14.4s\n"
        "fmax v26.4s, v26.4s, v14.4s\n"
        "fmax v27.4s, v27.4s, v14.4s\n"
        "fmax v28.4s, v28.4s, v14.4s\n"
        "fmax v29.4s, v29.4s, v14.4s\n"
        "fmax v30.4s, v30.4s, v14.4s\n"
        "fmax v31.4s, v31.4s, v14.4s\n"

        // Apply the clamp_max bound
        "fmin v16.4s, v16.4s, v15.4s\n"
        "fmin v17.4s, v17.4s, v15.4s\n"
        "fmin v18.4s, v18.4s, v15.4s\n"
        "fmin v19.4s, v19.4s, v15.4s\n"
        "fmin v20.4s, v20.4s, v15.4s\n"
        "fmin v21.4s, v21.4s, v15.4s\n"
        "fmin v22.4s, v22.4s, v15.4s\n"
        "fmin v23.4s, v23.4s, v15.4s\n"
        "fmin v24.4s, v24.4s, v15.4s\n"
        "fmin v25.4s, v25.4s, v15.4s\n"
        "fmin v26.4s, v26.4s, v15.4s\n"
        "fmin v27.4s, v27.4s, v15.4s\n"
        "fmin v28.4s, v28.4s, v15.4s\n"
        "fmin v29.4s, v29.4s, v15.4s\n"
        "fmin v30.4s, v30.4s, v15.4s\n"
        "fmin v31.4s, v31.4s, v15.4s\n"

        // Compute how much of the 8x8 block of destination float values
        // that we have computed fits in the destination matrix. Typically,
        // all of it fits, but when the destination matrix shape is not a
        // multiple of 8x8, there are some 8x8 blocks along the boundaries
        // that do not fit entirely.
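        // For example, with dst_rows == 10 and row == 8 (hypothetical
        // values), only w1 == 2 rows of this final block fit.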
        "sub w1, %w[dst_rows], %w[row]\n"
        "sub w2, %w[dst_cols], %w[col]\n"
        "mov w3, #8\n"
        "cmp w1, #8\n"
        // Compute w1 = how many rows of the 8x8 block fit
        "csel w1, w1, w3, le\n"
        "cmp w2, #8\n"
        // Compute w2 = how many cols of the 8x8 block fit
        "csel w2, w2, w3, le\n"

        // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
        "cmp w1, w3\n"
        "ccmp w2, w3, 0, eq\n"
        // Yes, all of the 8x8 block fits, go to fast path.
        "beq 30f\n"
        // Not all of the 8x8 block fits.
        // Set (x3 address, x4 stride) to write to dst_tmp_buf
        "mov x3, %[dst_tmp_buf]\n"
        "mov x4, #32\n"
        "b 31f\n"
        "30:\n"
        // Yes, all of the 8x8 block fits.
        // Set (x3 address, x4 stride) to write directly to destination matrix.
        "mov x3, %[dst_ptr]\n"
        "mov x4, x11\n"
        "31:\n"

        // Write our float values to the destination described by
        // (x3 address, x4 stride).
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        "str q16, [x3, #0]\n"
        "str q17, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)
        "str q18, [x3, #0]\n"
        "str q19, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)
        "str q20, [x3, #0]\n"
        "str q21, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        "str q22, [x3, #0]\n"
        "str q23, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        "str q24, [x3, #0]\n"
        "str q25, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        "str q26, [x3, #0]\n"
        "str q27, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        "str q28, [x3, #0]\n"
        "str q29, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        "str q30, [x3, #0]\n"
        "str q31, [x3, #16]\n"
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // If all of the 8x8 block fits, we just finished writing it to the
        // destination, so we skip the next part.
        "beq 41f\n"
        // Not all of the 8x8 block fits in the destination matrix. We just
        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
        // it to copy into the destination matrix the part that fits.
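        // In this copy loop, x3 walks dst_tmp_buf (fixed 32-byte column
        // stride) and x4 walks the real destination (column stride x11);
        // w5 counts rows up to w1 and w6 counts columns up to w2.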
        "mov x3, %[dst_tmp_buf]\n"
        "mov x4, %[dst_ptr]\n"
        "mov w6, #0\n"
        "50:\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov w5, #0\n"
        "51:\n"
        "ldr w7, [x3, x5, lsl #2]\n"
        "str w7, [x4, x5, lsl #2]\n"
        "add w5, w5, #1\n"
        "cmp w5, w1\n"
        "blt 51b\n"
        "add w6, w6, #1\n"
        "add x3, x3, #32\n"
        "add x4, x4, x11\n"
        "cmp w6, w2\n"
        "blt 50b\n"
        "41:\n"
        "add %[dst_ptr], %[dst_ptr], #32\n"
        // At this point we have completely finished writing values to the
        // destination matrix for the current block.

        // Reload some params --- we had used x5 -- x7 for a few other things
        // since the last time we had loaded them.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"

        // Move to the next block of the destination matrix, for the next iter
        // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
        // been updated earlier.
        // Have we reached the end row?
        "cmp %w[row], w7\n"
        "beq 20f\n"  // yes, end row.
        // Not end row. Move to the next row.
        "add %w[row], %w[row], #8\n"
        "b 21f\n"
        "20:\n"
        // Was already at end row.
        "mov %w[row], w6\n"  // Move back to first row.
        "add %w[col], %w[col], #8\n"  // Move to the next column.
        "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
        "mov %[dst_ptr], %[dst_col_ptr]\n"
        "21:\n"

        // Main loop exit condition: have we hit the end column?
        "cmp %w[col], w8\n"

        // w1 is the number of levels of depth that remain to load
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently depth - 1.
        "sub w1, w12, #1\n"

        "ble 1b\n"

        // clang-format on

        : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr),
          [row] "+r"(row), [col] "+r"(col)
        : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
          "x12", "x13", "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
          "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
          "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
          "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_FLAGS
#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR

#endif  // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)

}  // namespace ruy