1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // This is a standalone testbed and benchmark for gemmlowp-style GEMM kernels,
16 // either doing integer or float arithmetic.
17 // It verifies that a kernel produces correct results, then benchmarks it.
18 //
19 // Some benchmark results are recorded in this spreadsheet:
20 //
21 // https://docs.google.com/spreadsheets/d/1UPbzbp9rdsD6RXxOr5q6AZ0n1omgEknLYO2ogiw6Kqk/edit?usp=sharing
22 //
23 // This program is entirely self-contained, and can be compiled manually
24 // such as suggested in the command lines below.
25 // It currently supports only Android/ARM but would trivially generalize to
26 // other OSes (it's mostly standard POSIX) or architectures (each kernel
27 // targets a specific architecture, one may simply add more).
28
29 /*
30 Build and run this benchmark on Android/ARM/32bit:
31 ~/android/toolchains/arm-linux-androideabi/bin/arm-linux-androideabi-clang++ \
32 -fPIE -pie -O3 --std=c++11 standalone/neon-gemm-kernel-benchmark.cc -o \
33 /tmp/benchmark -mfloat-abi=softfp -mfpu=neon-vfpv4 && adb push /tmp/benchmark \
34 /data/local/tmp && adb shell /data/local/tmp/benchmark
35 Build and run this benchmark on Android/ARM/64bit:
36 ~/android/toolchains/aarch64-linux-android/bin/aarch64-linux-android-clang++ \
37 -fPIE -static -O3 --std=c++11 standalone/neon-gemm-kernel-benchmark.cc -o \
38 /tmp/benchmark && adb push /tmp/benchmark /data/local/tmp && adb shell \
39 /data/local/tmp/benchmark
40 */
41
42 // For big.LITTLE devices, use 'taskset' to select which cores to benchmark.
43 //
44 // The syntax is: taskset <mask> <commandline>
45 // where mask is a binary mask where each bit corresponds to a core,
46 // and low bits are little cores.
47 //
48 // Examples:
49 // Nexus 5X big cores: taskset 30
50 // Nexus 5X little cores: taskset 0f
51 // Pixel XL big cores: taskset 0c
52 // Pixel XL little cores: taskset 03
53 //
54 // Full example:
55 // adb shell taskset 0c /data/local/tmp/benchmark
56
57 #include <sched.h>
58 #include <unistd.h>
59
60 #include <algorithm>
61 #include <cassert>
62 #include <cstdint>
63 #include <cstdlib>
64 #include <cstring>
65 #include <iostream>
66 #include <random>
67 #include <type_traits>
68
69 #if !defined(__arm__) && !defined(__aarch64__) && \
70 !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa))
71 #error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections).
72 #endif
73
74 #if defined(__arm__) || defined(__aarch64__)
75 #include <arm_neon.h>
76 #endif
77
78 #if defined(__mips)
79 #include <msa.h>
80
81 // Some convenience macros to hide differences between MIPS32 and MIPS64.
82 #ifdef __LP64__
83 #define GEMMLOWP_MIPS_XADDIU "daddiu"
84 #else
85 #define GEMMLOWP_MIPS_XADDIU "addiu"
86 #endif
87 #endif
88
89 // Typically one wants to fit in L1 cache, and GEMM implementations
90 // are carefully optimized to tune their access patterns to that effect.
91 // Most devices have at least 16k of L1 cache. The Kraits have exactly 16k.
92 const int kDefaultCacheSizeK = 16;
93
94 const int kCacheLineSize = 64;
95
96 // These definitions are used for labels within assembly code. Required for
97 // iOS toolchain compatibility.
98 #define GEMMLOWP_LABEL_AFTER_LOOP "1"
99 #define GEMMLOWP_LABEL_LOOP "2"
100 #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
101 #define GEMMLOWP_LABEL_STORE "4"
102
103 // BEGIN code copied from gemmlowp/internal/kernel.h
104
105 // Explanation of general gemmlowp terminology
106 // ===========================================
107 //
108 // We use the following abbreviations:
109 // LHS = "left-hand side"
110 // RHS = "right-hand side"
111 // Sometimes when referring to either LHS or RHS, we just say a "Side".
112 //
113 // In a matrix product of a MxK matrix times a KxN matrix,
114 // we call K the 'depth'. Note that M is the number of rows
115 // of the result (and of the LHS), and N is the number of columns
116 // of the result (and of the RHS).
117 //
118 // In each of the LHS and RHS matrices, we call 'width' the
119 // other dimension, besides the depth. So in the LHS, 'width'
120 // is the number of rows, while in the RHS, 'width' is the number
121 // of columns.
122 //
123 // So in the LHS MxK matrix, the depth is K and the width in M.
124 // And in the RHS KxN matrix, the depth is K and the width in N.
125 //
126 // This is illustrated in this picture:
127 //
128 // RHS width
129 // <----------------->
130 // +-----------------+ ^
131 // | RHS | | Depth
132 // +-----------------+ v
133 // ^ +--+ +-----------------+
134 // | |L | | |
135 // LHS width | |H | | Result |
136 // | |S | | |
137 // v +--+ +-----------------+
138 // <-->
139 // Depth
140
141 // Explanation of gemmlowp kernel formats and "cells"
142 // ==================================================
143 //
144 // Kernels operate on small LHS and RHS blocks that fit in registers.
145 // These blocks are stored contiguously in memory, but not always
146 // in a traditional column-major or row-major order; instead,
147 // they consist of a number of sub-blocks, which we call "cells",
148 // that are stored in column-major or row-major order. However,
149 // what really matters to us is not so much rows vs columns, but
150 // rather width vs depth. So we refer to "width-major" and "depth-major"
151 // storage orders. In the LHS, width-major means row-major,
152 // while in the RHS, width-major means column-major.
153 // There is also a third possibility, "diagonal order",
154 // which is unused at the moment.
155 //
156 // We aim to treat both sides, LHS and RHS, on an equal footing,
157 // so we call them both 'sides'. A KernelFormat thus is just a pair
158 // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
159 // contains a CellFormat and a number of cells; cells are only ever
160 // stacked in the width dimension, which means stacked vertically in the
161 // LHS and stacked horizondally in the RHS.
162 //
163 // Example
164 // =======
165 //
166 // Let's work out the data layout expected by a kernel having the
167 // following format (the struct names here are defined below in this file):
168 //
169 // KernelFormat<
170 // KernelSideFormat<CellFormat<3, 4>, 3>,
171 // KernelSideFormat<CellFormat<5, 4>, 2>
172 // >
173 //
174 // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
175 // 3 cells, each cell having dimensions (width=3, depth=4), laid out in
176 // DepthMajor order (the default value, see CellFormat). In the LHS,
177 // DepthMajor means column-major, so the LHS cells are of size 3x4 in
178 // column-major order, so the LHS layout is:
179 //
180 // 0 3 6 9
181 // 1 4 7 10
182 // 2 5 8 11
183 // 12 15 18 21
184 // 13 16 19 22
185 // 14 17 20 23
186 // 24 27 30 33
187 // 25 28 31 34
188 // 26 29 32 35
189 //
190 // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
191 // 2 cells each having dimensions (width=5, depth=4), laid out in
192 // DepthMajor order (the default value, see CellFormat). In the RHS,
193 // DepthMajor means row-major, so the RHS cells are of size 4x5 in
194 // row-major order, so the RHS layout is:
195 //
196 // 0 1 2 3 4 20 21 22 23 24
197 // 5 6 7 8 9 25 26 27 28 29
198 // 10 11 12 13 14 30 31 32 33 34
199 // 15 16 17 18 19 35 36 37 38 39
200
201 // CellOrder enumerates the possible storage orders (=layouts) for
202 // a cell (see explanation above).
203 enum class CellOrder { DepthMajor, WidthMajor, Diagonal };
204
205 // CellFormat describes how data is laid
206 // out in a cell. That is, a CellOrder together with actual dimensions.
207 template <int tWidth, int tDepth, CellOrder tOrder>
208 struct CellFormat {
209 static const int kWidth = tWidth;
210 static const int kDepth = tDepth;
211 static const CellOrder kOrder = tOrder;
212
213 static const int kSize = kWidth * kDepth;
214 };
215
216 // KernelSideFormat describes how data is laid out in a kernel side
217 // (i.e. LHS or RHS). That is, a CellFormat together with a number of
218 // cells. These cells are always stacked in the Width dimension.
219 // For example, in the LHS case, the Width dimension is the rows dimension,
220 // se we're saying that in the LHS, cells are stacked vertically.
221 // We never stack cells in the Depth dimension.
222 template <typename tCellFormat, int tCells>
223 struct KernelSideFormat {
224 typedef tCellFormat Cell;
225 static const int kCells = tCells;
226 static const int kWidth = kCells * Cell::kWidth;
227 static const int kDepth = Cell::kDepth;
228 };
229
230 // KernelFormat describes fully the input data layout that a kernel expects.
231 // It consists of two KernelSideFormat's, one for LHS and one for RHS.
232 template <typename tLhs, typename tRhs>
233 struct KernelFormat {
234 typedef tLhs Lhs;
235 typedef tRhs Rhs;
236
237 static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, "");
238 static const int kDepth = Lhs::Cell::kDepth;
239 static const int kRows = Lhs::Cell::kWidth * Lhs::kCells;
240 static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
241 };
242
CellOrderName(CellOrder o)243 inline const char* CellOrderName(CellOrder o) {
244 switch (o) {
245 case CellOrder::DepthMajor:
246 return "DepthMajor";
247 case CellOrder::WidthMajor:
248 return "WidthMajor";
249 case CellOrder::Diagonal:
250 return "Diagonal";
251 default:
252 assert(false);
253 return nullptr;
254 }
255 }
256
257 // Returns the offset into a cell, at which a given coefficient is stored.
258 template <typename CellFormat>
OffsetIntoCell(int w,int d)259 inline int OffsetIntoCell(int w, int d) {
260 switch (CellFormat::kOrder) {
261 case CellOrder::DepthMajor:
262 return w + d * CellFormat::kWidth;
263 case CellOrder::WidthMajor:
264 return d + w * CellFormat::kDepth;
265 case CellOrder::Diagonal:
266 assert(CellFormat::kWidth == CellFormat::kDepth);
267 static const int size = CellFormat::kWidth;
268 return ((size + w - d) * size + d) % (size * size);
269 default:
270 assert(false);
271 return 0;
272 }
273 }
274
275 // END code copied from gemmlowp/internal/kernel.h
276
277 #ifdef __arm__
278
279 // This is the current standard kernel in gemmlowp, see:
280 // https://github.com/google/gemmlowp/blob/b1e2a29ff866680028f3080efc244e10e8dd7f46/internal/kernel_neon.h#L33
281 struct NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators {
282 typedef std::uint8_t OperandType;
283 typedef std::uint32_t AccumulatorType;
284 typedef KernelFormat<
285 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
286 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
287 Format;
RunNEON_32bit_GEMM_Uint8Operands_Uint32Accumulators288 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
289 AccumulatorType* accum_ptr, int depth) {
290 asm volatile(
291 // Load 1 Rhs cell of size 2x4
292 "vld1.8 {d0}, [%[rhs_ptr]]!\n"
293 // Load 3 Lhs cells of size 4x2 each
294 "vld1.8 {d2}, [%[lhs_ptr]]!\n"
295 "vld1.8 {d4}, [%[lhs_ptr]]!\n"
296 "vld1.8 {d6}, [%[lhs_ptr]]!\n"
297 // Load accumulators
298 "mov r0, %[accum_ptr]\n"
299 "vld1.32 {d8, d9}, [r0]!\n"
300 "vld1.32 {d16, d17}, [r0]!\n"
301 "vld1.32 {d24, d25}, [r0]!\n"
302 "vld1.32 {d10, d11}, [r0]!\n"
303 "vld1.32 {d18, d19}, [r0]!\n"
304 "vld1.32 {d26, d27}, [r0]!\n"
305 "vld1.32 {d12, d13}, [r0]!\n"
306 "vld1.32 {d20, d21}, [r0]!\n"
307 "vld1.32 {d28, d29}, [r0]!\n"
308 "vld1.32 {d14, d15}, [r0]!\n"
309 "vld1.32 {d22, d23}, [r0]!\n"
310 "vld1.32 {d30, d31}, [r0]!\n"
311
312 "subs %[depth], #2\n"
313
314 "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
315
316 GEMMLOWP_LABEL_LOOP
317 ":\n"
318 // Overview of register layout:
319 //
320 // A 2x4 cell of Rhs is stored in 16bit in d0--d1 (q0).
321 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in d2--d7
322 // (q1--q3).
323 // A 12x4 block of accumulators is stored in 32bit in q4--q15.
324 //
325 // +-----+-----+-----+-----+
326 // |d0[0]|d0[1]|d0[2]|d0[3]|
327 // Rhs +-----+-----+-----+-----+
328 // |d1[0]|d1[1]|d1[2]|d1[3]|
329 // +-----+-----+-----+-----+
330 //
331 // | | | | |
332 //
333 // Lhs | | | | |
334 //
335 // +--+--+ - - - - +-----+-----+-----+-----+
336 // |d2|d3| | q4 | q5 | q6 | q7 |
337 // |d2|d3| | q4 | q5 | q6 | q7 |
338 // |d2|d3| | q4 | q5 | q6 | q7 |
339 // |d2|d3| | q4 | q5 | q6 | q7 |
340 // +--+--+ - - - - +-----+-----+-----+-----+
341 // |d4|d5| | q8 | q9 | q10 | q11 |
342 // |d4|d5| | q8 | q9 | q10 | q11 |
343 // |d4|d5| | q8 | q9 | q10 | q11 |
344 // |d4|d5| | q8 | q9 | q10 | q11 |
345 // +--+--+ - - - - +-----+-----+-----+-----+
346 // |d6|d7| | q12 | q13 | q14 | q15 |
347 // |d6|d7| | q12 | q13 | q14 | q15 |
348 // |d6|d7| | q12 | q13 | q14 | q15 |
349 // |d6|d7| | q12 | q13 | q14 | q15 |
350 // +--+--+ - - - - +-----+-----+-----+-----+
351 //
352 // Accumulator
353
354 // Expand Lhs/Rhs cells to 16 bit.
355 // Note: moving theses vmovls further down to allow for
356 // longer data pipelining helps a little on A57 but is
357 // harmful on A53 --- It looks as if A53 doesn't like
358 // interleaving vmovl's into the vmlal's.
359 "vmovl.u8 q0, d0\n"
360 "vmovl.u8 q1, d2\n"
361 "vmovl.u8 q2, d4\n"
362 "vmovl.u8 q3, d6\n"
363
364 // Multiply-accumulate, level of depth 0
365 "vmlal.u16 q4, d2, d0[0]\n"
366 "vmlal.u16 q5, d2, d0[1]\n"
367 "vmlal.u16 q6, d2, d0[2]\n"
368 "vmlal.u16 q7, d2, d0[3]\n"
369 "vldr d2, [%[lhs_ptr]]\n"
370 "vmlal.u16 q8, d4, d0[0]\n"
371 "vmlal.u16 q9, d4, d0[1]\n"
372 "vmlal.u16 q10, d4, d0[2]\n"
373 "vmlal.u16 q11, d4, d0[3]\n"
374 "vldr d4, [%[lhs_ptr], #8]\n"
375 "vmlal.u16 q12, d6, d0[0]\n"
376 "vmlal.u16 q13, d6, d0[1]\n"
377 "vmlal.u16 q14, d6, d0[2]\n"
378 "vmlal.u16 q15, d6, d0[3]\n"
379 "vldr d6, [%[lhs_ptr], #16]\n"
380 "vldr d0, [%[rhs_ptr]]\n"
381
382 // Multiply-accumulate, level of depth 1
383 "vmlal.u16 q4, d3, d1[0]\n"
384 "vmlal.u16 q5, d3, d1[1]\n"
385 "add %[lhs_ptr], #24\n"
386 "vmlal.u16 q6, d3, d1[2]\n"
387 "vmlal.u16 q7, d3, d1[3]\n"
388 "add %[rhs_ptr], #8\n"
389 "vmlal.u16 q8, d5, d1[0]\n"
390 "vmlal.u16 q9, d5, d1[1]\n"
391 "subs %[depth], #2\n"
392 "vmlal.u16 q10, d5, d1[2]\n"
393 "vmlal.u16 q11, d5, d1[3]\n"
394 "vmlal.u16 q12, d7, d1[0]\n"
395 "vmlal.u16 q13, d7, d1[1]\n"
396 "vmlal.u16 q14, d7, d1[2]\n"
397 "vmlal.u16 q15, d7, d1[3]\n"
398
399 "bne " GEMMLOWP_LABEL_LOOP "b\n"
400
401 GEMMLOWP_LABEL_AFTER_LOOP
402 ":\n"
403
404 // Expand Lhs/Rhs cells to 16 bit.
405 "vmovl.u8 q0, d0\n"
406 "vmovl.u8 q1, d2\n"
407 "vmovl.u8 q2, d4\n"
408 "vmovl.u8 q3, d6\n"
409
410 // Multiply-accumulate, level of depth 0
411 "vmlal.u16 q4, d2, d0[0]\n"
412 "vmlal.u16 q5, d2, d0[1]\n"
413 "vmlal.u16 q6, d2, d0[2]\n"
414 "vmlal.u16 q7, d2, d0[3]\n"
415 "vmlal.u16 q8, d4, d0[0]\n"
416 "vmlal.u16 q9, d4, d0[1]\n"
417 "vmlal.u16 q10, d4, d0[2]\n"
418 "vmlal.u16 q11, d4, d0[3]\n"
419 "vmlal.u16 q12, d6, d0[0]\n"
420 "vmlal.u16 q13, d6, d0[1]\n"
421 "vmlal.u16 q14, d6, d0[2]\n"
422 "vmlal.u16 q15, d6, d0[3]\n"
423
424 // Multiply-accumulate, level of depth 1
425 "vmlal.u16 q4, d3, d1[0]\n"
426 "vmlal.u16 q5, d3, d1[1]\n"
427 "vmlal.u16 q6, d3, d1[2]\n"
428 "vmlal.u16 q7, d3, d1[3]\n"
429 "vmlal.u16 q8, d5, d1[0]\n"
430 "vmlal.u16 q9, d5, d1[1]\n"
431 "vmlal.u16 q10, d5, d1[2]\n"
432 "vmlal.u16 q11, d5, d1[3]\n"
433 "vmlal.u16 q12, d7, d1[0]\n"
434 "vmlal.u16 q13, d7, d1[1]\n"
435 "vmlal.u16 q14, d7, d1[2]\n"
436 "vmlal.u16 q15, d7, d1[3]\n"
437
438 // Store accumulators
439 "mov r0, %[accum_ptr]\n"
440 "vst1.32 {d8, d9}, [r0]!\n"
441 "vst1.32 {d16, d17}, [r0]!\n"
442 "vst1.32 {d24, d25}, [r0]!\n"
443 "vst1.32 {d10, d11}, [r0]!\n"
444 "vst1.32 {d18, d19}, [r0]!\n"
445 "vst1.32 {d26, d27}, [r0]!\n"
446 "vst1.32 {d12, d13}, [r0]!\n"
447 "vst1.32 {d20, d21}, [r0]!\n"
448 "vst1.32 {d28, d29}, [r0]!\n"
449 "vst1.32 {d14, d15}, [r0]!\n"
450 "vst1.32 {d22, d23}, [r0]!\n"
451 "vst1.32 {d30, d31}, [r0]!\n"
452 : // outputs
453 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
454 [depth] "+r"(depth)
455 : // inputs
456 [accum_ptr] "r"(accum_ptr)
457 : // clobbers
458 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
459 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
460 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
461 "d28", "d29", "d30", "d31");
462 }
463 };
464
465 // This is Maciek Chociej's fast kernel not expanding operands,
466 // from gemmlowp/meta/. Search for
467 // mul_3x8_3x8_int32_lhsadd_rhsadd
468 // in this file:
469 // https://raw.githubusercontent.com/google/gemmlowp/e4b9d858b6637d5d0058bfa3d869d2b95864251b/meta/single_thread_gemm.h
470 struct NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand {
471 typedef std::uint8_t OperandType;
472 typedef std::uint32_t AccumulatorType;
473 typedef KernelFormat<
474 KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>,
475 KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1> >
476 Format;
RunNEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand477 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
478 AccumulatorType* accum_ptr, int depth) {
479 asm volatile(
480 // Clear aggregators.
481 "vmov.i32 q0, #0\n"
482 "vmov.i32 q1, #0\n"
483 "vmov.i32 q2, #0\n"
484 "vmov.i32 q3, q0\n"
485 "vmov.i32 q4, q1\n"
486 "vmov.i32 q5, q2\n"
487 "vmov.i32 q6, q3\n"
488 "vmov.i32 q7, q4\n"
489 "vmov.i32 q8, q5\n"
490
491 // Loop head
492 GEMMLOWP_LABEL_LOOP
493 ":\n"
494
495 // Subtract counter.
496 "subs %[depth], %[depth], #8\n"
497
498 "vld1.8 {d18, d19, d20}, [%[rhs_ptr]]!\n"
499 "vld1.8 {d21, d22, d23}, [%[lhs_ptr]]!\n"
500 "vmull.u8 q12, d18, d21\n"
501 "vmull.u8 q13, d18, d22\n"
502 "vmull.u8 q14, d18, d23\n"
503 "vmull.u8 q15, d19, d21\n"
504 "vpadal.u16 q0, q12\n"
505 "vpadal.u16 q1, q13\n"
506 "vpadal.u16 q2, q14\n"
507 "vpadal.u16 q3, q15\n"
508 "vmull.u8 q12, d19, d22\n"
509 "vmull.u8 q13, d19, d23\n"
510 "vmull.u8 q14, d20, d21\n"
511 "vmull.u8 q15, d20, d22\n"
512 "vmull.u8 q9, d20, d23\n"
513 "vpadal.u16 q4, q12\n"
514 "vpadal.u16 q5, q13\n"
515 "vpadal.u16 q6, q14\n"
516 "vpadal.u16 q7, q15\n"
517 "vpadal.u16 q8, q9\n"
518
519 // Loop branch
520 "bne " GEMMLOWP_LABEL_LOOP
521 "b\n"
522
523 // Horizontal reduce aggregators, step 1
524 "vpadd.u32 d0, d0, d1\n"
525 "vpadd.u32 d2, d2, d3\n"
526 "vpadd.u32 d4, d4, d5\n"
527 "vpadd.u32 d6, d6, d7\n"
528 "vpadd.u32 d8, d8, d9\n"
529 "vpadd.u32 d10, d10, d11\n"
530 "vpadd.u32 d12, d12, d13\n"
531 "vpadd.u32 d14, d14, d15\n"
532 "vpadd.u32 d16, d16, d17\n"
533
534 // Horizontal reduce aggregators, step 2
535 "vpadd.u32 d0, d0, d2\n"
536 "vpadd.u32 d1, d4, d4\n"
537 "vpadd.u32 d6, d6, d8\n"
538 "vpadd.u32 d7, d10, d10\n"
539 "vpadd.u32 d12, d12, d14\n"
540 "vpadd.u32 d13, d16, d16\n"
541
542 // Load accumulators
543 "mov r0, %[accum_ptr]\n"
544 "vld1.32 {d2}, [r0]!\n"
545 "vld1.32 {d3[0]}, [r0]!\n"
546
547 "vld1.32 {d8}, [r0]!\n"
548 "vld1.32 {d9[0]}, [r0]!\n"
549
550 "vld1.32 {d14}, [r0]!\n"
551 "vld1.32 {d15[0]}, [r0]!\n"
552
553 // Accumulate
554 "vadd.s32 q0, q0, q1\n"
555 "vadd.s32 q3, q3, q4\n"
556 "vadd.s32 q6, q6, q7\n"
557
558 // Store accumulators
559 "mov r0, %[accum_ptr]\n"
560 "vst1.32 {d0}, [r0]!\n"
561 "vst1.32 {d1[0]}, [r0]!\n"
562
563 "vst1.32 {d6}, [r0]!\n"
564 "vst1.32 {d7[0]}, [r0]!\n"
565
566 "vst1.32 {d12}, [r0]!\n"
567 "vst1.32 {d13[0]}, [r0]!\n"
568 : // outputs
569 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
570 [depth] "+r"(depth)
571 : // inputs
572 [accum_ptr] "r"(accum_ptr)
573 : // clobbers
574 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
575 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
576 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
577 "d28", "d29", "d30", "d31");
578 }
579 };
580
581 // Fast kernel operating on int8 operands.
582 // It is assumed that one of the two int8 operands only takes values
583 // in [-127, 127], while the other may freely range in [-128, 127].
584 // The issue with both operands taking the value -128 is that:
585 // -128*-128 + -128*-128 == -32768 overflows int16.
586 // Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
587 // range. That is the basic idea of this kernel.
588 struct NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
589 typedef std::int8_t OperandType;
590 typedef std::int32_t AccumulatorType;
591 typedef KernelFormat<
592 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
593 KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
594 Format;
RunNEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits595 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
596 AccumulatorType* accum_ptr, int depth) {
597 std::size_t start_depth = 123;
598 std::size_t run_depth = depth;
599 std::size_t dst_col_stride = 4;
600 AccumulatorType* dst_ptr = accum_ptr;
601 asm volatile(
602
603 // Overview of register layout:
604 //
605 // A 2x16 block of Rhs is stored in 8 bit in d0--d3.
606 // A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only
607 // half of the register space required, so we loop over these registers
608 // twice. Only half of it, a 2x16 block, is stored in d4--d7 at
609 // any given time.
610 //
611 // A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit
612 // components which need to be horizontally-added at the end)
613 //
614 // The Lhs vectors are multiplied by the Rhs vectors with a widening
615 // multiply over the 8 first levels of depth, producing int16x8
616 // vectors of products for each position in the accumulator matrix.
617 // Here comes the special trick: since the operands are signed int8,
618 // their range being [ -2^7 , 2^7 ), their products are in range
619 // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values
620 // without any risk of overflowing int16.
621 // We thus proceed with the 8 next levels of depth, multiplying
622 // again Lhs by Rhs, accumulating into this existing int16x8 vector.
623 //
624 // Only then, having processed 16 levels of depth, do we need to
625 // horizontally add these int16x8 accumulators into the final
626 // int32x4 accumulators.
627 //
628 // As we do not have enough registers to store all 16 int16x8
629 // temporary-16bit-accumulators, we have them cycle through q4--q7.
630 //
631 //
632 // Register layout (ignoring the q4--q7 temporary 16bit accumulators):
633 //
634 // +----+----+
635 // | d0 | d2 |
636 // | . | . |
637 // | . | . |
638 // | . | . |
639 // Rhs +----+----+
640 // | d1 | d3 |
641 // | . | . |
642 // | . | . |
643 // | . | . |
644 // +----+----+
645 //
646 // | | |
647 //
648 // Lhs | | |
649 //
650 // +--------+--------+ - - - - +----+----+
651 // | d4 ... | d5 ... | | q8 | q9 |
652 // | d6 ... | d7 ... | | q10| q11|
653 // | d4 ... | d5 ... | | q12| q13|
654 // | d6 ... | d7 ... | | q14| q15|
655 // +--------+--------+ - - - - +----+----+
656 //
657 // Accumulator
658 //
659
660 // Clear accumulators, and, interleaved with it,
661 // initial loads of the first loop iteration,
662 // taken out of the loop so that in the loop itself we have
663 // optimal streaming of data from memory.
664 "vldr d0, [%[rhs_ptr], #0]\n"
665 "vmov.i32 q8, #0\n"
666 "vldr d4, [%[lhs_ptr], #0]\n"
667 "vmov.i32 q9, #0\n"
668 "vldr d2, [%[rhs_ptr], #16]\n"
669 "vmov.i32 q10, q8\n"
670 "vldr d6, [%[lhs_ptr], #16]\n"
671 "vmov.i32 q11, q8\n"
672 "vldr d1, [%[rhs_ptr], #8]\n"
673 "vmov.i32 q12, q8\n"
674 "vldr d5, [%[lhs_ptr], #8]\n"
675 "vmov.i32 q13, q8\n"
676 "vldr d3, [%[rhs_ptr], #24]\n"
677 "vmov.i32 q14, q8\n"
678 "vldr d7, [%[lhs_ptr], #24]\n"
679 "vmov.i32 q15, q8\n"
680
681 // General loop.
682 GEMMLOWP_LABEL_LOOP
683 ":\n"
684
685 // Multiply 8 first levels of depth.
686 "vmull.s8 q4, d0, d4\n"
687 "add %[rhs_ptr], %[rhs_ptr], #32\n"
688 "vmull.s8 q5, d2, d4\n"
689 "vldr d4, [%[lhs_ptr], #32]\n"
690 "vmull.s8 q6, d0, d6\n"
691 "vmull.s8 q7, d2, d6\n"
692 "vldr d6, [%[lhs_ptr], #48]\n"
693
694 // Multiply-accumulate second-half, again into the same
695 // 16bit local accumulator registers. This is where we
696 // take advantage of having int8 instead of uint8 and therefore
697 // being able to accumulate two products into int16.
698 "vmlal.s8 q4, d1, d5\n"
699 "vmlal.s8 q5, d3, d5\n"
700 "vldr d5, [%[lhs_ptr], #40]\n"
701 "vmlal.s8 q6, d1, d7\n"
702 "vmlal.s8 q7, d3, d7\n"
703 "vldr d7, [%[lhs_ptr], #56]\n"
704
705 // Add pairwise, accumulate into 32-bit accumulators.
706 "vpadal.s16 q8, q4\n"
707 "add %[lhs_ptr], %[lhs_ptr], #64\n"
708 "vpadal.s16 q9, q5\n"
709 "subs %[run_depth], %[run_depth], #16\n"
710 "vpadal.s16 q10, q6\n"
711 "vpadal.s16 q11, q7\n"
712
713 "beq " GEMMLOWP_LABEL_AFTER_LOOP
714 "f\n"
715
716 // Multiply first half.
717 "vmull.s8 q4, d0, d4\n"
718 "vmull.s8 q5, d2, d4\n"
719 "vldr d4, [%[lhs_ptr], #0]\n"
720 "vmull.s8 q6, d0, d6\n"
721 "vldr d0, [%[rhs_ptr], #0]\n"
722 "vmull.s8 q7, d2, d6\n"
723 "vldr d2, [%[rhs_ptr], #16]\n"
724
725 // Multiply-accumulate second-half, again into the same
726 // 16bit local accumulator registers. This is where we
727 // take advantage of having int8 instead of uint8 and therefore
728 // being able to accumulate two products into int16.
729 "vmlal.s8 q4, d1, d5\n"
730 "vldr d6, [%[lhs_ptr], #16]\n"
731 "vmlal.s8 q5, d3, d5\n"
732 "vldr d5, [%[lhs_ptr], #8]\n"
733 "vmlal.s8 q6, d1, d7\n"
734 "vldr d1, [%[rhs_ptr], #8]\n"
735 "vmlal.s8 q7, d3, d7\n"
736 "vldr d3, [%[rhs_ptr], #24]\n"
737
738 // Add pairwise, accumulate into 32-bit accumulators.
739 "vpadal.s16 q12, q4\n"
740 "vldr d7, [%[lhs_ptr], #24]\n"
741 "vpadal.s16 q13, q5\n"
742 "vpadal.s16 q14, q6\n"
743 "vpadal.s16 q15, q7\n"
744
745 "b " GEMMLOWP_LABEL_LOOP "b\n"
746
747 GEMMLOWP_LABEL_AFTER_LOOP
748 ":\n"
749
750 // Multiply first half.
751 "vmull.s8 q4, d0, d4\n"
752 "vmull.s8 q5, d2, d4\n"
753 "vmull.s8 q6, d0, d6\n"
754 "vmull.s8 q7, d2, d6\n"
755
756 // Multiply-accumulate second-half, again into the same
757 // 16bit local accumulator registers. This is where we
758 // take advantage of having int8 instead of uint8 and therefore
759 // being able to accumulate two products into int16.
760 "vmlal.s8 q4, d1, d5\n"
761 "vmlal.s8 q5, d3, d5\n"
762 "vmlal.s8 q6, d1, d7\n"
763 "vmlal.s8 q7, d3, d7\n"
764
765 // Add pairwise, accumulate into 32-bit accumulators.
766 "vpadal.s16 q12, q4\n"
767 "vpadal.s16 q13, q5\n"
768 "vpadal.s16 q14, q6\n"
769 "vpadal.s16 q15, q7\n"
770 "cmp %[start_depth], #0\n"
771
772 // Reduce 32bit accumulators horizontally.
773 "vpadd.s32 d0, d16, d17\n"
774 "vpadd.s32 d1, d18, d19\n"
775 "vpadd.s32 d2, d20, d21\n"
776 "vpadd.s32 d3, d22, d23\n"
777 "vpadd.s32 d4, d24, d25\n"
778 "vpadd.s32 d5, d26, d27\n"
779 "vpadd.s32 d6, d28, d29\n"
780 "vpadd.s32 d7, d30, d31\n"
781
782 "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
783 "f\n"
784
785 // Reduce 32bit accumulators horizontally, second pass
786 // (each pass adds pairwise. we need to add 4-wise).
787 "vpadd.s32 d8, d0, d2\n"
788 "vpadd.s32 d9, d4, d6\n"
789 "vpadd.s32 d10, d1, d3\n"
790 "vpadd.s32 d11, d5, d7\n"
791
792 "b " GEMMLOWP_LABEL_STORE "f\n"
793
794 GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
795 ":\n"
796
797 // Reduce 32bit accumulators horizontally, second pass
798 // (each pass adds pairwise. we need to add 4-wise),
799 // and load destination values from memory.
800 "mov r0, %[dst_ptr]\n"
801 "vld1.32 {d16, d17}, [r0]!\n"
802 "vpadd.s32 d8, d0, d2\n"
803 "vpadd.s32 d9, d4, d6\n"
804 "vld1.32 {d18, d19}, [r0]\n"
805 "vpadd.s32 d10, d1, d3\n"
806 "vpadd.s32 d11, d5, d7\n"
807
808 // Add horizontally-reduced accumulators into
809 // the values loaded from memory
810 "vadd.s32 q4, q8, q4\n"
811 "vadd.s32 q5, q9, q5\n"
812
813 GEMMLOWP_LABEL_STORE
814 ":\n"
815 // Store back into memory
816 "mov r0, %[dst_ptr]\n"
817 "vst1.32 {d8, d9}, [r0]!\n"
818 "vst1.32 {d10, d11}, [r0]\n"
819 : // outputs
820 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
821 [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth)
822 : // inputs
823 [start_depth] "r"(start_depth)
824 : // clobbers
825 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
826 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
827 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
828 "d28", "d29", "d30", "d31");
829 }
830 };
831
832 // We don't actually use int32*int32 in production. This is just an
833 // experiment to help dissociate the effect of integer-vs-float, from the
834 // effect of operands width.
835 struct NEON_32bit_GEMM_Int32_WithScalar {
836 typedef std::int32_t OperandType;
837 typedef std::int32_t AccumulatorType;
838 typedef KernelFormat<
839 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
840 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
841 Format;
RunNEON_32bit_GEMM_Int32_WithScalar842 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
843 AccumulatorType* accum_ptr, int depth) {
844 asm volatile(
845 // Load accumulators
846 "mov r0, %[accum_ptr]\n"
847 "vld1.32 {d8, d9}, [r0]!\n"
848 "vld1.32 {d16, d17}, [r0]!\n"
849 "vld1.32 {d24, d25}, [r0]!\n"
850 "vld1.32 {d10, d11}, [r0]!\n"
851 "vld1.32 {d18, d19}, [r0]!\n"
852 "vld1.32 {d26, d27}, [r0]!\n"
853 "vld1.32 {d12, d13}, [r0]!\n"
854 "vld1.32 {d20, d21}, [r0]!\n"
855 "vld1.32 {d28, d29}, [r0]!\n"
856 "vld1.32 {d14, d15}, [r0]!\n"
857 "vld1.32 {d22, d23}, [r0]!\n"
858 "vld1.32 {d30, d31}, [r0]!\n"
859
860 GEMMLOWP_LABEL_LOOP
861 ":\n"
862
863 // Load 1 Rhs cell of size 1x4
864 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"
865
866 // Load 3 Lhs cells of size 4x1 each
867 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
868 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
869 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"
870
871 // Multiply-accumulate
872 "vmla.s32 q4, q1, d0[0]\n"
873 "vmla.s32 q5, q1, d0[1]\n"
874 "vmla.s32 q6, q1, d1[0]\n"
875 "vmla.s32 q7, q1, d1[1]\n"
876 "vmla.s32 q8, q2, d0[0]\n"
877 "vmla.s32 q9, q2, d0[1]\n"
878 "vmla.s32 q10, q2, d1[0]\n"
879 "vmla.s32 q11, q2, d1[1]\n"
880 "vmla.s32 q12, q3, d0[0]\n"
881 "vmla.s32 q13, q3, d0[1]\n"
882 "vmla.s32 q14, q3, d1[0]\n"
883 "vmla.s32 q15, q3, d1[1]\n"
884
885 // Loop. Decrement loop index (depth) by 1, since we just handled 1
886 // level of depth.
887 "subs %[depth], #1\n"
888 "bne " GEMMLOWP_LABEL_LOOP
889 "b\n"
890
891 // Store accumulators
892 "mov r0, %[accum_ptr]\n"
893 "vst1.32 {d8, d9}, [r0]!\n"
894 "vst1.32 {d16, d17}, [r0]!\n"
895 "vst1.32 {d24, d25}, [r0]!\n"
896 "vst1.32 {d10, d11}, [r0]!\n"
897 "vst1.32 {d18, d19}, [r0]!\n"
898 "vst1.32 {d26, d27}, [r0]!\n"
899 "vst1.32 {d12, d13}, [r0]!\n"
900 "vst1.32 {d20, d21}, [r0]!\n"
901 "vst1.32 {d28, d29}, [r0]!\n"
902 "vst1.32 {d14, d15}, [r0]!\n"
903 "vst1.32 {d22, d23}, [r0]!\n"
904 "vst1.32 {d30, d31}, [r0]!\n"
905 : // outputs
906 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
907 [depth] "+r"(depth)
908 : // inputs
909 [accum_ptr] "r"(accum_ptr)
910 : // clobbers
911 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
912 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
913 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
914 "d28", "d29", "d30", "d31");
915 }
916 };
917
918 // Not very efficient kernel, just an experiment to see what we can do
919 // without using NEON multiply-with-scalar instructions.
920 struct NEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar {
921 typedef float OperandType;
922 typedef float AccumulatorType;
923 typedef KernelFormat<
924 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
925 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
926 Format;
RunNEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar927 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
928 AccumulatorType* accum_ptr, int depth) {
929 asm volatile(
930 // Load accumulators
931 "mov r0, %[accum_ptr]\n"
932 "vld1.32 {d8, d9}, [r0]!\n"
933 "vld1.32 {d16, d17}, [r0]!\n"
934 "vld1.32 {d24, d25}, [r0]!\n"
935 "vld1.32 {d10, d11}, [r0]!\n"
936 "vld1.32 {d18, d19}, [r0]!\n"
937 "vld1.32 {d26, d27}, [r0]!\n"
938 "vld1.32 {d12, d13}, [r0]!\n"
939 "vld1.32 {d20, d21}, [r0]!\n"
940 "vld1.32 {d28, d29}, [r0]!\n"
941 "vld1.32 {d14, d15}, [r0]!\n"
942 "vld1.32 {d22, d23}, [r0]!\n"
943 "vld1.32 {d30, d31}, [r0]!\n"
944
945 GEMMLOWP_LABEL_LOOP
946 ":\n"
947
948 // Load 3 Lhs cells of size 4x1 each
949 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
950 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
951 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"
952
953 // Multiply-accumulate
954 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
955 "vmla.f32 q4, q1, q0\n"
956 "vmla.f32 q8, q2, q0\n"
957 "vmla.f32 q12, q3, q0\n"
958 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
959 "vmla.f32 q5, q1, q0\n"
960 "vmla.f32 q9, q2, q0\n"
961 "vmla.f32 q13, q3, q0\n"
962 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
963 "vmla.f32 q6, q1, q0\n"
964 "vmla.f32 q10, q2, q0\n"
965 "vmla.f32 q14, q3, q0\n"
966 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
967 "vmla.f32 q7, q1, q0\n"
968 "vmla.f32 q11, q2, q0\n"
969 "vmla.f32 q15, q3, q0\n"
970
971 // Loop. Decrement loop index (depth) by 1, since we just handled 1
972 // level of depth.
973 "subs %[depth], #1\n"
974 "bne " GEMMLOWP_LABEL_LOOP
975 "b\n"
976
977 // Store accumulators
978 "mov r0, %[accum_ptr]\n"
979 "vst1.32 {d8, d9}, [r0]!\n"
980 "vst1.32 {d16, d17}, [r0]!\n"
981 "vst1.32 {d24, d25}, [r0]!\n"
982 "vst1.32 {d10, d11}, [r0]!\n"
983 "vst1.32 {d18, d19}, [r0]!\n"
984 "vst1.32 {d26, d27}, [r0]!\n"
985 "vst1.32 {d12, d13}, [r0]!\n"
986 "vst1.32 {d20, d21}, [r0]!\n"
987 "vst1.32 {d28, d29}, [r0]!\n"
988 "vst1.32 {d14, d15}, [r0]!\n"
989 "vst1.32 {d22, d23}, [r0]!\n"
990 "vst1.32 {d30, d31}, [r0]!\n"
991 : // outputs
992 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
993 [depth] "+r"(depth)
994 : // inputs
995 [accum_ptr] "r"(accum_ptr)
996 : // clobbers
997 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
998 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
999 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
1000 "d28", "d29", "d30", "d31");
1001 }
1002 };
1003
1004 // Not very efficient kernel, just an experiment to see what we can do
1005 // without using NEON multiply-with-scalar instructions.
1006 // This variant is relevant as on ARMv7 FMA does not have a with-scalar variant.
1007 struct NEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar {
1008 typedef float OperandType;
1009 typedef float AccumulatorType;
1010 typedef KernelFormat<
1011 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
1012 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
1013 Format;
RunNEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar1014 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1015 AccumulatorType* accum_ptr, int depth) {
1016 asm volatile(
1017 // Load accumulators
1018 "mov r0, %[accum_ptr]\n"
1019 "vld1.32 {d8, d9}, [r0]!\n"
1020 "vld1.32 {d16, d17}, [r0]!\n"
1021 "vld1.32 {d24, d25}, [r0]!\n"
1022 "vld1.32 {d10, d11}, [r0]!\n"
1023 "vld1.32 {d18, d19}, [r0]!\n"
1024 "vld1.32 {d26, d27}, [r0]!\n"
1025 "vld1.32 {d12, d13}, [r0]!\n"
1026 "vld1.32 {d20, d21}, [r0]!\n"
1027 "vld1.32 {d28, d29}, [r0]!\n"
1028 "vld1.32 {d14, d15}, [r0]!\n"
1029 "vld1.32 {d22, d23}, [r0]!\n"
1030 "vld1.32 {d30, d31}, [r0]!\n"
1031
1032 GEMMLOWP_LABEL_LOOP
1033 ":\n"
1034
1035 // Load 3 Lhs cells of size 4x1 each
1036 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
1037 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
1038 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"
1039
1040 // Multiply-accumulate
1041 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
1042 "vfma.f32 q4, q1, q0\n"
1043 "vfma.f32 q8, q2, q0\n"
1044 "vfma.f32 q12, q3, q0\n"
1045 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
1046 "vfma.f32 q5, q1, q0\n"
1047 "vfma.f32 q9, q2, q0\n"
1048 "vfma.f32 q13, q3, q0\n"
1049 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
1050 "vfma.f32 q6, q1, q0\n"
1051 "vfma.f32 q10, q2, q0\n"
1052 "vfma.f32 q14, q3, q0\n"
1053 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n"
1054 "vfma.f32 q7, q1, q0\n"
1055 "vfma.f32 q11, q2, q0\n"
1056 "vfma.f32 q15, q3, q0\n"
1057
1058 // Loop. Decrement loop index (depth) by 1, since we just handled 1
1059 // level of depth.
1060 "subs %[depth], #1\n"
1061 "bne " GEMMLOWP_LABEL_LOOP
1062 "b\n"
1063
1064 // Store accumulators
1065 "mov r0, %[accum_ptr]\n"
1066 "vst1.32 {d8, d9}, [r0]!\n"
1067 "vst1.32 {d16, d17}, [r0]!\n"
1068 "vst1.32 {d24, d25}, [r0]!\n"
1069 "vst1.32 {d10, d11}, [r0]!\n"
1070 "vst1.32 {d18, d19}, [r0]!\n"
1071 "vst1.32 {d26, d27}, [r0]!\n"
1072 "vst1.32 {d12, d13}, [r0]!\n"
1073 "vst1.32 {d20, d21}, [r0]!\n"
1074 "vst1.32 {d28, d29}, [r0]!\n"
1075 "vst1.32 {d14, d15}, [r0]!\n"
1076 "vst1.32 {d22, d23}, [r0]!\n"
1077 "vst1.32 {d30, d31}, [r0]!\n"
1078 : // outputs
1079 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1080 [depth] "+r"(depth)
1081 : // inputs
1082 [accum_ptr] "r"(accum_ptr)
1083 : // clobbers
1084 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
1085 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
1086 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
1087 "d28", "d29", "d30", "d31");
1088 }
1089 };
1090
1091 // This is the "most natural" kernel, using NEON multiply-with-scalar
1092 // instructions.
1093 struct NEON_32bit_GEMM_Float32_MLA_WithScalar {
1094 typedef float OperandType;
1095 typedef float AccumulatorType;
1096 typedef KernelFormat<
1097 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
1098 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
1099 Format;
RunNEON_32bit_GEMM_Float32_MLA_WithScalar1100 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1101 AccumulatorType* accum_ptr, int depth) {
1102 asm volatile(
1103 // Load accumulators
1104 "mov r0, %[accum_ptr]\n"
1105 "vld1.32 {d8, d9}, [r0]!\n"
1106 "vld1.32 {d16, d17}, [r0]!\n"
1107 "vld1.32 {d24, d25}, [r0]!\n"
1108 "vld1.32 {d10, d11}, [r0]!\n"
1109 "vld1.32 {d18, d19}, [r0]!\n"
1110 "vld1.32 {d26, d27}, [r0]!\n"
1111 "vld1.32 {d12, d13}, [r0]!\n"
1112 "vld1.32 {d20, d21}, [r0]!\n"
1113 "vld1.32 {d28, d29}, [r0]!\n"
1114 "vld1.32 {d14, d15}, [r0]!\n"
1115 "vld1.32 {d22, d23}, [r0]!\n"
1116 "vld1.32 {d30, d31}, [r0]!\n"
1117
1118 GEMMLOWP_LABEL_LOOP
1119 ":\n"
1120
1121 // Load 1 Rhs cell of size 1x4
1122 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"
1123
1124 // Load 3 Lhs cells of size 4x1 each
1125 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
1126 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
1127 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"
1128
1129 // Multiply-accumulate
1130 "vmla.f32 q4, q1, d0[0]\n"
1131 "vmla.f32 q5, q1, d0[1]\n"
1132 "vmla.f32 q6, q1, d1[0]\n"
1133 "vmla.f32 q7, q1, d1[1]\n"
1134 "vmla.f32 q8, q2, d0[0]\n"
1135 "vmla.f32 q9, q2, d0[1]\n"
1136 "vmla.f32 q10, q2, d1[0]\n"
1137 "vmla.f32 q11, q2, d1[1]\n"
1138 "vmla.f32 q12, q3, d0[0]\n"
1139 "vmla.f32 q13, q3, d0[1]\n"
1140 "vmla.f32 q14, q3, d1[0]\n"
1141 "vmla.f32 q15, q3, d1[1]\n"
1142
1143 // Loop. Decrement loop index (depth) by 1, since we just handled 1
1144 // level of depth.
1145 "subs %[depth], #1\n"
1146 "bne " GEMMLOWP_LABEL_LOOP
1147 "b\n"
1148
1149 // Store accumulators
1150 "mov r0, %[accum_ptr]\n"
1151 "vst1.32 {d8, d9}, [r0]!\n"
1152 "vst1.32 {d16, d17}, [r0]!\n"
1153 "vst1.32 {d24, d25}, [r0]!\n"
1154 "vst1.32 {d10, d11}, [r0]!\n"
1155 "vst1.32 {d18, d19}, [r0]!\n"
1156 "vst1.32 {d26, d27}, [r0]!\n"
1157 "vst1.32 {d12, d13}, [r0]!\n"
1158 "vst1.32 {d20, d21}, [r0]!\n"
1159 "vst1.32 {d28, d29}, [r0]!\n"
1160 "vst1.32 {d14, d15}, [r0]!\n"
1161 "vst1.32 {d22, d23}, [r0]!\n"
1162 "vst1.32 {d30, d31}, [r0]!\n"
1163 : // outputs
1164 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1165 [depth] "+r"(depth)
1166 : // inputs
1167 [accum_ptr] "r"(accum_ptr)
1168 : // clobbers
1169 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
1170 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
1171 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
1172 "d28", "d29", "d30", "d31");
1173 }
1174 };
1175
1176 // Faster kernel contributed by ARM in 64bit form
1177 // (see NEON_64bit_GEMM_Float32_WithScalar_A53) then ported to 32bit code.
1178 // Tuned for A53.
1179 struct NEON_32bit_GEMM_Float32_WithScalar_A53 {
1180 typedef float OperandType;
1181 typedef float AccumulatorType;
1182 typedef KernelFormat<
1183 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
1184 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
1185 Format;
RunNEON_32bit_GEMM_Float32_WithScalar_A531186 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1187 AccumulatorType* accum_ptr, int depth) {
1188 asm volatile(
1189 // Load accumulators
1190 "mov r0, %[accum_ptr]\n"
1191 "vld1.32 {d8, d9}, [r0]!\n"
1192 "vld1.32 {d16, d17}, [r0]!\n"
1193 "vld1.32 {d24, d25}, [r0]!\n"
1194 "vld1.32 {d10, d11}, [r0]!\n"
1195 "vld1.32 {d18, d19}, [r0]!\n"
1196 "vld1.32 {d26, d27}, [r0]!\n"
1197 "vld1.32 {d12, d13}, [r0]!\n"
1198 "vld1.32 {d20, d21}, [r0]!\n"
1199 "vld1.32 {d28, d29}, [r0]!\n"
1200 "vld1.32 {d14, d15}, [r0]!\n"
1201 "vld1.32 {d22, d23}, [r0]!\n"
1202 "vld1.32 {d30, d31}, [r0]!\n"
1203
1204 // Overview of register layout:
1205 //
1206 // A 1x4 cell of Rhs is stored in d0--d1 (q0).
1207 // A 12x1 block of 3 4x1 cells Lhs is stored in d2--d7
1208 // (q1--q3).
1209 // A 12x4 block of accumulators is stored in q4--q15.
1210 //
1211 // +-----+-----+-----+-----+
1212 // Rhs |d0[0]|d0[1]|d1[0]|d1[1]|
1213 // +-----+-----+-----+-----+
1214 //
1215 // | | | | |
1216 //
1217 // Lhs | | | | |
1218 //
1219 // +--+- - - - - - +-----+-----+-----+-----+
1220 // |d2| | q4 | q5 | q6 | q7 |
1221 // |d2| | q4 | q5 | q6 | q7 |
1222 // |d3| | q4 | q5 | q6 | q7 |
1223 // |d3| | q4 | q5 | q6 | q7 |
1224 // +--+- - - - - - +-----+-----+-----+-----+
1225 // |d4| | q8 | q9 | q10 | q11 |
1226 // |d4| | q8 | q9 | q10 | q11 |
1227 // |d5| | q8 | q9 | q10 | q11 |
1228 // |d5| | q8 | q9 | q10 | q11 |
1229 // +--+ - - - - - - +-----+-----+-----+-----+
1230 // |d6| | q12 | q13 | q14 | q15 |
1231 // |d6| | q12 | q13 | q14 | q15 |
1232 // |d7| | q12 | q13 | q14 | q15 |
1233 // |d7| | q12 | q13 | q14 | q15 |
1234 // +--+- - - - - - +-----+-----+-----+-----+
1235 //
1236 // Accumulator
1237
1238 // Load Rhs cell
1239 "vldr d0, [%[rhs_ptr]]\n"
1240 "ldr r2, [%[rhs_ptr], #8]\n"
1241 "ldr r3, [%[rhs_ptr], #12]\n"
1242
1243 // Load 1st Lhs Cell
1244 "vld1.32 {d2, d3}, [%[lhs_ptr]]\n"
1245
1246 GEMMLOWP_LABEL_LOOP
1247 ":\n"
1248
1249 "vldr d4, [%[lhs_ptr], #16]\n" // Load 1st half of 2nd Lhs cell
1250 "vmov d1, r2, r3\n" // Prepare 2nd half of Rhs cell
1251 "vmla.f32 q4, q1, d0[0]\n" // Multiply 1st Lhs cell with column 0
1252 "ldr r2, [%[lhs_ptr], #24]\n" // Load 2nd half of 2nd Lhs cell, part 1
1253 "vmla.f32 q5, q1, d0[1]\n" // Multiply 1st Lhs cell with column 1
1254 "ldr r3, [%[lhs_ptr], #28]\n" // Load 2nd half of 2nd Lhs cell, part 2
1255 "vmla.f32 q6, q1, d1[0]\n" // Multiply 1st Lhs cell with column 2
1256 "subs %[depth], #1\n"
1257
1258 "vldr d6, [%[lhs_ptr], #32]\n" // Load 1st half of 3rd Lhs cell
1259 "vmov d5, r2, r3\n" // Prepare 2nd half of 2nd Lhs cell
1260 "vmla.f32 q7, q1, d1[1]\n" // Multiply 1st Lhs cell with column 3
1261 "ldr r2, [%[lhs_ptr], #40]\n" // Load 2nd half of 3rd Lhs cell, part 1
1262 "vmla.f32 q8, q2, d0[0]\n" // Multiply 2nd Lhs cell with column 0
1263 "ldr r3, [%[lhs_ptr], #44]\n" // Load 2nd half of 3rd Lhs cell, part 2
1264 "vmla.f32 q9, q2, d0[1]\n" // Multiply 2nd Lhs cell with column 1
1265 "add %[rhs_ptr], %[rhs_ptr], #16\n" // Move forward by 1 Rhs cell
1266
1267 "vldr d2, [%[lhs_ptr], #48]\n" // Load 1st half of 1st Lhs cell of next
1268 // iteration
1269 "vmov d7, r2, r3\n" // Prepare 2nd half of 3rd Lhs cell
1270 "vmla.f32 q10, q2, d1[0]\n" // Multiply 2nd Lhs cell with column 2
1271 "ldr r2, [%[lhs_ptr], #56]\n" // Load 2nd half of 1st Lhs cell of next
1272 // iter, part 1
1273 "vmla.f32 q12, q3, d0[0]\n" // Multiply 3rd Lhs cell with column 0
1274 "ldr r3, [%[lhs_ptr], #60]\n" // Load 2nd half of 1st Lhs cell of next
1275 // iter, part 2
1276 "vmla.f32 q13, q3, d0[1]\n" // Multiply 3rd Lhs cell with column 1
1277 "add %[lhs_ptr], %[lhs_ptr], #48\n" // Move forward by 3 Lhs cells
1278
1279 "vldr d0, [%[rhs_ptr]]\n" // Load 1st half of Rhs cell of next
1280 // iteration
1281 "vmov d3, r2, r3\n" // Prepare 2nd half of 1st Lhs cell of next
1282 // iteration
1283 "vmla.f32 q11, q2, d1[1]\n" // Multiply 2nd Lhs cell with column 3
1284 "ldr r2, [%[rhs_ptr], #8]\n" // Load 2nd half of Rhs cell of next
1285 // iteration, part 1
1286 "vmla.f32 q14, q3, d1[0]\n" // Multiply 3rd Lhs cell with column 2
1287 "ldr r3, [%[rhs_ptr], #12]\n" // Load 2nd half of Rhs cell of next
1288 // iteration, part 2
1289 "vmla.f32 q15, q3, d1[1]\n" // Multiply 3rd Lhs cell with column 3
1290
1291 // Loop branch. This will dual issue in fmla cycle 3 of the 4th block.
1292 "bne " GEMMLOWP_LABEL_LOOP
1293 "b\n"
1294
1295 // Store accumulators
1296 "mov r0, %[accum_ptr]\n"
1297 "vst1.32 {d8, d9}, [r0]!\n"
1298 "vst1.32 {d16, d17}, [r0]!\n"
1299 "vst1.32 {d24, d25}, [r0]!\n"
1300 "vst1.32 {d10, d11}, [r0]!\n"
1301 "vst1.32 {d18, d19}, [r0]!\n"
1302 "vst1.32 {d26, d27}, [r0]!\n"
1303 "vst1.32 {d12, d13}, [r0]!\n"
1304 "vst1.32 {d20, d21}, [r0]!\n"
1305 "vst1.32 {d28, d29}, [r0]!\n"
1306 "vst1.32 {d14, d15}, [r0]!\n"
1307 "vst1.32 {d22, d23}, [r0]!\n"
1308 "vst1.32 {d30, d31}, [r0]!\n"
1309 : // outputs
1310 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1311 [depth] "+r"(depth)
1312 : // inputs
1313 [accum_ptr] "r"(accum_ptr)
1314 : // clobbers
1315 "cc", "memory", "r0", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5",
1316 "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16",
1317 "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26",
1318 "d27", "d28", "d29", "d30", "d31");
1319 }
1320 };
1321
1322 struct NEON_32bit_GEMM_Float32_WithScalar_A53_depth2 {
1323 typedef float OperandType;
1324 typedef float AccumulatorType;
1325 typedef KernelFormat<
1326 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
1327 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
1328 Format;
RunNEON_32bit_GEMM_Float32_WithScalar_A53_depth21329 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1330 AccumulatorType* accum_ptr, int depth) {
1331 asm volatile(
1332 // Load accumulators
1333 "mov r0, %[accum_ptr]\n"
1334 "vld1.32 {d8, d9}, [r0]!\n"
1335 "vld1.32 {d16, d17}, [r0]!\n"
1336 "vld1.32 {d24, d25}, [r0]!\n"
1337 "vld1.32 {d10, d11}, [r0]!\n"
1338 "vld1.32 {d18, d19}, [r0]!\n"
1339 "vld1.32 {d26, d27}, [r0]!\n"
1340 "vld1.32 {d12, d13}, [r0]!\n"
1341 "vld1.32 {d20, d21}, [r0]!\n"
1342 "vld1.32 {d28, d29}, [r0]!\n"
1343 "vld1.32 {d14, d15}, [r0]!\n"
1344 "vld1.32 {d22, d23}, [r0]!\n"
1345 "vld1.32 {d30, d31}, [r0]!\n"
1346
1347 // Overview of register layout:
1348 //
1349 // A 1x4 cell of Rhs is stored in d0--d1 (q0).
1350 // A 12x1 block of 3 4x1 cells Lhs is stored in d2--d7
1351 // (q1--q3).
1352 // A 12x4 block of accumulators is stored in q4--q15.
1353 //
1354 // +-----+-----+-----+-----+
1355 // Rhs |d0[0]|d0[1]|d1[0]|d1[1]|
1356 // +-----+-----+-----+-----+
1357 //
1358 // | | | | |
1359 //
1360 // Lhs | | | | |
1361 //
1362 // +--+- - - - - - +-----+-----+-----+-----+
1363 // |d2| | q4 | q5 | q6 | q7 |
1364 // |d2| | q4 | q5 | q6 | q7 |
1365 // |d3| | q4 | q5 | q6 | q7 |
1366 // |d3| | q4 | q5 | q6 | q7 |
1367 // +--+- - - - - - +-----+-----+-----+-----+
1368 // |d4| | q8 | q9 | q10 | q11 |
1369 // |d4| | q8 | q9 | q10 | q11 |
1370 // |d5| | q8 | q9 | q10 | q11 |
1371 // |d5| | q8 | q9 | q10 | q11 |
1372 // +--+ - - - - - - +-----+-----+-----+-----+
1373 // |d6| | q12 | q13 | q14 | q15 |
1374 // |d6| | q12 | q13 | q14 | q15 |
1375 // |d7| | q12 | q13 | q14 | q15 |
1376 // |d7| | q12 | q13 | q14 | q15 |
1377 // +--+- - - - - - +-----+-----+-----+-----+
1378 //
1379 // Accumulator
1380
1381 // Load Rhs cell
1382 "vldr d0, [%[rhs_ptr]]\n"
1383 "ldr r2, [%[rhs_ptr], #8]\n"
1384 "ldr r3, [%[rhs_ptr], #12]\n"
1385
1386 // Load 1st Lhs Cell
1387 "vld1.32 {d2, d3}, [%[lhs_ptr]]\n"
1388
1389 // Loop head - handling 2 levels of depth at once
1390 GEMMLOWP_LABEL_LOOP
1391 ":\n"
1392
1393 // Level of depth 1
1394
1395 "vldr d4, [%[lhs_ptr], #32]\n" // Load 1st half of 2nd Lhs cell
1396 "vmov d1, r2, r3\n" // Prepare 2nd half of Rhs cell
1397 "vmla.f32 q4, q1, d0[0]\n" // Multiply 1st Lhs cell with column 0
1398 "ldr r2, [%[lhs_ptr], #40]\n" // Load 2nd half of 2nd Lhs cell, part 1
1399 "vmla.f32 q5, q1, d0[1]\n" // Multiply 1st Lhs cell with column 1
1400 "ldr r3, [%[lhs_ptr], #44]\n" // Load 2nd half of 2nd Lhs cell, part 2
1401 "vmla.f32 q6, q1, d1[0]\n" // Multiply 1st Lhs cell with column 2
1402
1403 "vldr d6, [%[lhs_ptr], #64]\n" // Load 1st half of 3rd Lhs cell
1404 "vmov d5, r2, r3\n" // Prepare 2nd half of 2nd Lhs cell
1405 "vmla.f32 q7, q1, d1[1]\n" // Multiply 1st Lhs cell with column 3
1406 "ldr r2, [%[lhs_ptr], #72]\n" // Load 2nd half of 3rd Lhs cell, part 1
1407 "vmla.f32 q8, q2, d0[0]\n" // Multiply 2nd Lhs cell with column 0
1408 "ldr r3, [%[lhs_ptr], #76]\n" // Load 2nd half of 3rd Lhs cell, part 2
1409 "vmla.f32 q9, q2, d0[1]\n" // Multiply 2nd Lhs cell with column 1
1410
1411 "vldr d2, [%[lhs_ptr], #16]\n" // Load 1st half of 1st Lhs cell of next
1412 // iteration
1413 "vmov d7, r2, r3\n" // Prepare 2nd half of 3rd Lhs cell
1414 "vmla.f32 q10, q2, d1[0]\n" // Multiply 2nd Lhs cell with column 2
1415 "ldr r2, [%[lhs_ptr], #24]\n" // Load 2nd half of 1st Lhs cell of next
1416 // iter, part 1
1417 "vmla.f32 q12, q3, d0[0]\n" // Multiply 3rd Lhs cell with column 0
1418 "ldr r3, [%[lhs_ptr], #28]\n" // Load 2nd half of 1st Lhs cell of next
1419 // iter, part 2
1420 "vmla.f32 q13, q3, d0[1]\n" // Multiply 3rd Lhs cell with column 1
1421
1422 "vldr d0, [%[rhs_ptr], #16]\n" // Load 1st half of Rhs cell of next
1423 // iteration
1424 "vmov d3, r2, r3\n" // Prepare 2nd half of 1st Lhs cell of next
1425 // iteration
1426 "vmla.f32 q11, q2, d1[1]\n" // Multiply 2nd Lhs cell with column 3
1427 "ldr r2, [%[rhs_ptr], #24]\n" // Load 2nd half of Rhs cell of next
1428 // iteration, part 1
1429 "vmla.f32 q14, q3, d1[0]\n" // Multiply 3rd Lhs cell with column 2
1430 "ldr r3, [%[rhs_ptr], #28]\n" // Load 2nd half of Rhs cell of next
1431 // iteration, part 2
1432 "vmla.f32 q15, q3, d1[1]\n" // Multiply 3rd Lhs cell with column 3
1433
1434 // Level of depth 2
1435 "vldr d4, [%[lhs_ptr], #48]\n" // Load 1st half of 2nd Lhs cell
1436 "vmov d1, r2, r3\n" // Prepare 2nd half of Rhs cell
1437 "vmla.f32 q4, q1, d0[0]\n" // Multiply 1st Lhs cell with column 0
1438 "ldr r2, [%[lhs_ptr], #56]\n" // Load 2nd half of 2nd Lhs cell, part 1
1439 "vmla.f32 q5, q1, d0[1]\n" // Multiply 1st Lhs cell with column 1
1440 "ldr r3, [%[lhs_ptr], #60]\n" // Load 2nd half of 2nd Lhs cell, part 2
1441 "vmla.f32 q6, q1, d1[0]\n" // Multiply 1st Lhs cell with column 2
1442 "subs %[depth], #2\n" // Decrement depth counter
1443
1444 "vldr d6, [%[lhs_ptr], #80]\n" // Load 1st half of 3rd Lhs cell
1445 "vmov d5, r2, r3\n" // Prepare 2nd half of 2nd Lhs cell
1446 "vmla.f32 q7, q1, d1[1]\n" // Multiply 1st Lhs cell with column 3
1447 "ldr r2, [%[lhs_ptr], #88]\n" // Load 2nd half of 3rd Lhs cell, part 1
1448 "vmla.f32 q8, q2, d0[0]\n" // Multiply 2nd Lhs cell with column 0
1449 "ldr r3, [%[lhs_ptr], #92]\n" // Load 2nd half of 3rd Lhs cell, part 2
1450 "vmla.f32 q9, q2, d0[1]\n" // Multiply 2nd Lhs cell with column 1
1451 "add %[rhs_ptr], %[rhs_ptr], #32\n" // Move forward by 1 Rhs cell
1452
1453 "vldr d2, [%[lhs_ptr], #96]\n" // Load 1st half of 1st Lhs cell of next
1454 // iteration
1455 "vmov d7, r2, r3\n" // Prepare 2nd half of 3rd Lhs cell
1456 "vmla.f32 q10, q2, d1[0]\n" // Multiply 2nd Lhs cell with column 2
1457 "ldr r2, [%[lhs_ptr], #104]\n" // Load 2nd half of 1st Lhs cell of next
1458 // iter, part 1
1459 "vmla.f32 q12, q3, d0[0]\n" // Multiply 3rd Lhs cell with column 0
1460 "ldr r3, [%[lhs_ptr], #108]\n" // Load 2nd half of 1st Lhs cell of next
1461 // iter, part 2
1462 "vmla.f32 q13, q3, d0[1]\n" // Multiply 3rd Lhs cell with column 1
1463 "add %[lhs_ptr], %[lhs_ptr], #96\n" // Move forward by 3 Lhs cells
1464
1465 "vldr d0, [%[rhs_ptr]]\n" // Load 1st half of Rhs cell of next
1466 // iteration
1467 "vmov d3, r2, r3\n" // Prepare 2nd half of 1st Lhs cell of next
1468 // iteration
1469 "vmla.f32 q11, q2, d1[1]\n" // Multiply 2nd Lhs cell with column 3
1470 "ldr r2, [%[rhs_ptr], #8]\n" // Load 2nd half of Rhs cell of next
1471 // iteration, part 1
1472 "vmla.f32 q14, q3, d1[0]\n" // Multiply 3rd Lhs cell with column 2
1473 "ldr r3, [%[rhs_ptr], #12]\n" // Load 2nd half of Rhs cell of next
1474 // iteration, part 2
1475 "vmla.f32 q15, q3, d1[1]\n" // Multiply 3rd Lhs cell with column 3
1476
1477 // Loop branch. This will dual issue in fmla cycle 3 of the 4th block.
1478 //"bne loop_%=\n"
1479 "bne " GEMMLOWP_LABEL_LOOP
1480 "b\n"
1481
1482 // Store accumulators
1483 "mov r0, %[accum_ptr]\n"
1484 "vst1.32 {d8, d9}, [r0]!\n"
1485 "vst1.32 {d16, d17}, [r0]!\n"
1486 "vst1.32 {d24, d25}, [r0]!\n"
1487 "vst1.32 {d10, d11}, [r0]!\n"
1488 "vst1.32 {d18, d19}, [r0]!\n"
1489 "vst1.32 {d26, d27}, [r0]!\n"
1490 "vst1.32 {d12, d13}, [r0]!\n"
1491 "vst1.32 {d20, d21}, [r0]!\n"
1492 "vst1.32 {d28, d29}, [r0]!\n"
1493 "vst1.32 {d14, d15}, [r0]!\n"
1494 "vst1.32 {d22, d23}, [r0]!\n"
1495 "vst1.32 {d30, d31}, [r0]!\n"
1496 : // outputs
1497 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1498 [depth] "+r"(depth)
1499 : // inputs
1500 [accum_ptr] "r"(accum_ptr)
1501 : // clobbers
1502 "cc", "memory", "r0", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5",
1503 "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16",
1504 "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26",
1505 "d27", "d28", "d29", "d30", "d31");
1506 }
1507 };
1508
1509 // This rotating variant performs well when permutations (vext) can be
1510 // dual-issued with arithmetic instructions.
1511 struct NEON_32bit_GEMM_Float32_MLA_Rotating {
1512 typedef float OperandType;
1513 typedef float AccumulatorType;
1514 typedef KernelFormat<
1515 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
1516 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
1517 Format;
RunNEON_32bit_GEMM_Float32_MLA_Rotating1518 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1519 AccumulatorType* accum_ptr, int depth) {
1520 asm volatile(
1521 // Load accumulators
1522 "mov r0, %[accum_ptr]\n"
1523 "vld1.32 {d8, d9}, [r0]!\n"
1524 "vld1.32 {d16, d17}, [r0]!\n"
1525 "vld1.32 {d24, d25}, [r0]!\n"
1526 "vld1.32 {d10, d11}, [r0]!\n"
1527 "vld1.32 {d18, d19}, [r0]!\n"
1528 "vld1.32 {d26, d27}, [r0]!\n"
1529 "vld1.32 {d12, d13}, [r0]!\n"
1530 "vld1.32 {d20, d21}, [r0]!\n"
1531 "vld1.32 {d28, d29}, [r0]!\n"
1532 "vld1.32 {d14, d15}, [r0]!\n"
1533 "vld1.32 {d22, d23}, [r0]!\n"
1534 "vld1.32 {d30, d31}, [r0]!\n"
1535
1536 #define NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS \
1537 "vtrn.32 q4, q5\n" \
1538 "vtrn.32 q6, q7\n" \
1539 "vswp d9, d12\n" \
1540 "vswp d11, d14\n" \
1541 "vtrn.32 q8, q9\n" \
1542 "vtrn.32 q10, q11\n" \
1543 "vswp d17, d20\n" \
1544 "vswp d19, d22\n" \
1545 "vtrn.32 q12, q13\n" \
1546 "vtrn.32 q14, q15\n" \
1547 "vswp d25, d28\n" \
1548 "vswp d27, d30\n"
1549
1550 #define NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(a, b, c) \
1551 NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS \
1552 "vext.32 q5, q5, q5, #" #a \
1553 "\n" \
1554 "vext.32 q6, q6, q6, #" #b \
1555 "\n" \
1556 "vext.32 q7, q7, q7, #" #c \
1557 "\n" \
1558 "vext.32 q9, q9, q9, #" #a \
1559 "\n" \
1560 "vext.32 q10, q10, q10, #" #b \
1561 "\n" \
1562 "vext.32 q11, q11, q11, #" #c \
1563 "\n" \
1564 "vext.32 q13, q13, q13, #" #a \
1565 "\n" \
1566 "vext.32 q14, q14, q14, #" #b \
1567 "\n" \
1568 "vext.32 q15, q15, q15, #" #c \
1569 "\n" NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS
1570
1571 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(1, 2, 3)
1572
1573 //"loop_%=:\n"
1574 GEMMLOWP_LABEL_LOOP
1575 ":\n"
1576
1577 // Load 1 Rhs cell of size 1x4
1578 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"
1579
1580 // Load 3 Lhs cells of size 4x1 each
1581 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
1582 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
1583 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"
1584
1585 // Multiply-accumulate
1586 "vmla.f32 q4, q1, q0\n"
1587 "vmla.f32 q8, q2, q0\n"
1588 "vmla.f32 q12, q3, q0\n"
1589 "vext.f32 q0, q0, q0, #1\n"
1590 "vmla.f32 q5, q1, q0\n"
1591 "vmla.f32 q9, q2, q0\n"
1592 "vmla.f32 q13, q3, q0\n"
1593 "vext.f32 q0, q0, q0, #1\n"
1594 "vmla.f32 q6, q1, q0\n"
1595 "vmla.f32 q10, q2, q0\n"
1596 "vmla.f32 q14, q3, q0\n"
1597 "vext.f32 q0, q0, q0, #1\n"
1598 "vmla.f32 q7, q1, q0\n"
1599 "vmla.f32 q11, q2, q0\n"
1600 "vmla.f32 q15, q3, q0\n"
1601
1602 // Loop. Decrement loop index (depth) by 1, since we just handled 1
1603 // level of depth.
1604 "subs %[depth], #1\n"
1605 //"bne loop_%=\n"
1606 "bne " GEMMLOWP_LABEL_LOOP
1607 "b\n"
1608
1609 // Store accumulators
1610 "mov r0, %[accum_ptr]\n"
1611
1612 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(3, 2, 1)
1613
1614 "vst1.32 {d8, d9}, [r0]!\n"
1615 "vst1.32 {d16, d17}, [r0]!\n"
1616 "vst1.32 {d24, d25}, [r0]!\n"
1617 "vst1.32 {d10, d11}, [r0]!\n"
1618 "vst1.32 {d18, d19}, [r0]!\n"
1619 "vst1.32 {d26, d27}, [r0]!\n"
1620 "vst1.32 {d12, d13}, [r0]!\n"
1621 "vst1.32 {d20, d21}, [r0]!\n"
1622 "vst1.32 {d28, d29}, [r0]!\n"
1623 "vst1.32 {d14, d15}, [r0]!\n"
1624 "vst1.32 {d22, d23}, [r0]!\n"
1625 "vst1.32 {d30, d31}, [r0]!\n"
1626 : // outputs
1627 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1628 [depth] "+r"(depth)
1629 : // inputs
1630 [accum_ptr] "r"(accum_ptr)
1631 : // clobbers
1632 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
1633 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
1634 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
1635 "d28", "d29", "d30", "d31");
1636 }
1637 };
1638
1639 // This rotating variant performs well when permutations (vext) can be
1640 // dual-issued with arithmetic instructions. It is relevant as the rotating
1641 // approach removes the need for multiply-with-scalar instructions, and ARMv7
1642 // FMA does not have a with-scalar variant.
1643 struct NEON_32bit_GEMM_Float32_FMA_Rotating {
1644 typedef float OperandType;
1645 typedef float AccumulatorType;
1646 typedef KernelFormat<
1647 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
1648 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> >
1649 Format;
RunNEON_32bit_GEMM_Float32_FMA_Rotating1650 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1651 AccumulatorType* accum_ptr, int depth) {
1652 asm volatile(
1653 // Load accumulators
1654 "mov r0, %[accum_ptr]\n"
1655 "vld1.32 {d8, d9}, [r0]!\n"
1656 "vld1.32 {d16, d17}, [r0]!\n"
1657 "vld1.32 {d24, d25}, [r0]!\n"
1658 "vld1.32 {d10, d11}, [r0]!\n"
1659 "vld1.32 {d18, d19}, [r0]!\n"
1660 "vld1.32 {d26, d27}, [r0]!\n"
1661 "vld1.32 {d12, d13}, [r0]!\n"
1662 "vld1.32 {d20, d21}, [r0]!\n"
1663 "vld1.32 {d28, d29}, [r0]!\n"
1664 "vld1.32 {d14, d15}, [r0]!\n"
1665 "vld1.32 {d22, d23}, [r0]!\n"
1666 "vld1.32 {d30, d31}, [r0]!\n"
1667
1668 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(1, 2, 3)
1669
1670 //"loop_%=:\n"
1671 GEMMLOWP_LABEL_LOOP
1672 ":\n"
1673
1674 // Load 1 Rhs cell of size 1x4
1675 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n"
1676
1677 // Load 3 Lhs cells of size 4x1 each
1678 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
1679 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n"
1680 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n"
1681
1682 // Multiply-accumulate
1683 "vfma.f32 q4, q1, q0\n"
1684 "vfma.f32 q8, q2, q0\n"
1685 "vfma.f32 q12, q3, q0\n"
1686 "vext.f32 q0, q0, q0, #1\n"
1687 "vfma.f32 q5, q1, q0\n"
1688 "vfma.f32 q9, q2, q0\n"
1689 "vfma.f32 q13, q3, q0\n"
1690 "vext.f32 q0, q0, q0, #1\n"
1691 "vfma.f32 q6, q1, q0\n"
1692 "vfma.f32 q10, q2, q0\n"
1693 "vfma.f32 q14, q3, q0\n"
1694 "vext.f32 q0, q0, q0, #1\n"
1695 "vfma.f32 q7, q1, q0\n"
1696 "vfma.f32 q11, q2, q0\n"
1697 "vfma.f32 q15, q3, q0\n"
1698
1699 // Loop. Decrement loop index (depth) by 1, since we just handled 1
1700 // level of depth.
1701 "subs %[depth], #1\n"
1702 //"bne loop_%=\n"
1703 "bne " GEMMLOWP_LABEL_LOOP "b\n"
1704
1705 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(3, 2, 1)
1706
1707 // Store accumulators
1708 "mov r0, %[accum_ptr]\n"
1709 "vst1.32 {d8, d9}, [r0]!\n"
1710 "vst1.32 {d16, d17}, [r0]!\n"
1711 "vst1.32 {d24, d25}, [r0]!\n"
1712 "vst1.32 {d10, d11}, [r0]!\n"
1713 "vst1.32 {d18, d19}, [r0]!\n"
1714 "vst1.32 {d26, d27}, [r0]!\n"
1715 "vst1.32 {d12, d13}, [r0]!\n"
1716 "vst1.32 {d20, d21}, [r0]!\n"
1717 "vst1.32 {d28, d29}, [r0]!\n"
1718 "vst1.32 {d14, d15}, [r0]!\n"
1719 "vst1.32 {d22, d23}, [r0]!\n"
1720 "vst1.32 {d30, d31}, [r0]!\n"
1721 : // outputs
1722 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1723 [depth] "+r"(depth)
1724 : // inputs
1725 [accum_ptr] "r"(accum_ptr)
1726 : // clobbers
1727 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
1728 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
1729 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
1730 "d28", "d29", "d30", "d31");
1731 }
1732 };
1733
1734 #endif // __arm__
1735
1736 #ifdef __aarch64__
1737
1738 // This is the current standard kernel in gemmlowp, see:
1739 // https://github.com/google/gemmlowp/blob/b1e2a29ff866680028f3080efc244e10e8dd7f46/internal/kernel_neon.h#L646
1740 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators {
1741 typedef std::uint8_t OperandType;
1742 typedef std::uint32_t AccumulatorType;
1743 typedef KernelFormat<
1744 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
1745 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
1746 Format;
RunNEON_64bit_GEMM_Uint8Operands_Uint32Accumulators1747 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
1748 AccumulatorType* accum_ptr, int depth) {
1749 asm volatile(
1750 // Load 1 Rhs cell of size 2x8
1751 "ld1 {v5.8b}, [%[rhs_ptr]], #8\n"
1752 "ld1 {v6.8b}, [%[rhs_ptr]], #8\n"
1753
1754 // Load 3 Lhs cells of size 4x2 each
1755 "ld1 {v2.8b}, [%[lhs_ptr]], #8\n"
1756 "ld1 {v3.8b}, [%[lhs_ptr]], #8\n"
1757 "ld1 {v4.8b}, [%[lhs_ptr]], #8\n"
1758
1759 "subs %w[depth], %w[depth], #2\n"
1760
1761 // Load accumulators
1762 "mov x0, %[accum_ptr]\n"
1763 "ld1 {v8.16b}, [x0], #16\n"
1764 "ld1 {v16.16b}, [x0], #16\n"
1765 "ld1 {v24.16b}, [x0], #16\n"
1766 "ld1 {v9.16b}, [x0], #16\n"
1767 "ld1 {v17.16b}, [x0], #16\n"
1768 "ld1 {v25.16b}, [x0], #16\n"
1769 "ld1 {v10.16b}, [x0], #16\n"
1770 "ld1 {v18.16b}, [x0], #16\n"
1771 "ld1 {v26.16b}, [x0], #16\n"
1772 "ld1 {v11.16b}, [x0], #16\n"
1773 "ld1 {v19.16b}, [x0], #16\n"
1774 "ld1 {v27.16b}, [x0], #16\n"
1775 "ld1 {v12.16b}, [x0], #16\n"
1776 "ld1 {v20.16b}, [x0], #16\n"
1777 "ld1 {v28.16b}, [x0], #16\n"
1778 "ld1 {v13.16b}, [x0], #16\n"
1779 "ld1 {v21.16b}, [x0], #16\n"
1780 "ld1 {v29.16b}, [x0], #16\n"
1781 "ld1 {v14.16b}, [x0], #16\n"
1782 "ld1 {v22.16b}, [x0], #16\n"
1783 "ld1 {v30.16b}, [x0], #16\n"
1784 "ld1 {v15.16b}, [x0], #16\n"
1785 "ld1 {v23.16b}, [x0], #16\n"
1786 "ld1 {v31.16b}, [x0], #16\n"
1787
1788 "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
1789
1790 //"loop_%=:\n"
1791 GEMMLOWP_LABEL_LOOP
1792 ":\n"
1793
1794 // Overview of register layout:
1795 //
1796 // A 2x8 block of 2 2x4 cells of Rhs is stored in 16bit in v0--v1.
1797 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in v2--v4.
1798 // A 12x8 block of accumulators is stored in 32bit in v8--v31.
1799 //
1800 // +--------+--------+-----+--------+--------+
1801 // |v0.h[0] |v0.h[1] | ... |v1.h[2] |v1.h[3] |
1802 // Rhs +--------+--------+-----+--------+--------+
1803 // |v0.h[4] |v0.h[5] | ... |v1.h[6] |v1.h[7] |
1804 // +--------+--------+-----+--------+--------+
1805 //
1806 // | | | | | |
1807 //
1808 // Lhs | | | | | |
1809 //
1810 // +-------+-------+ - - +--------+--------+-----+--------+--------+
1811 // |v2.h[0]|v2.h[4]| |v8.s[0] |v9.s[0] | ... |v14.s[0]|v15.s[0]|
1812 // |v2.h[1]|v2.h[5]| |v8.s[1] |v9.s[1] | ... |v14.s[1]|v15.s[1]|
1813 // |v2.h[2]|v2.h[6]| |v8.s[2] |v9.s[2] | ... |v14.s[2]|v15.s[2]|
1814 // |v2.h[3]|v2.h[7]| |v8.s[3] |v9.s[3] | ... |v14.s[3]|v15.s[3]|
1815 // +-------+-------+ - - +--------+--------+-----+--------+--------+
1816 // |v3.h[0]|v3.h[4]| |v16.s[0]|v17.s[0]| ... |v22.s[0]|v23.s[0]|
1817 // |v3.h[1]|v3.h[5]| |v16.s[1]|v17.s[1]| ... |v22.s[1]|v23.s[1]|
1818 // |v3.h[2]|v3.h[6]| |v16.s[2]|v17.s[2]| ... |v22.s[2]|v23.s[2]|
1819 // |v3.h[3]|v3.h[7]| |v16.s[3]|v17.s[3]| ... |v22.s[3]|v23.s[3]|
1820 // +-------+-------+ - - +--------+--------+-----+--------+--------+
1821 // |v4.h[0]|v4.h[4]| |v24.s[0]|v25.s[0]| ... |v30.s[0]|v31.s[0]|
1822 // |v4.h[1]|v4.h[5]| |v24.s[1]|v25.s[1]| ... |v30.s[1]|v31.s[1]|
1823 // |v4.h[2]|v4.h[6]| |v24.s[2]|v25.s[2]| ... |v30.s[2]|v31.s[2]|
1824 // |v4.h[3]|v4.h[7]| |v24.s[3]|v25.s[3]| ... |v30.s[3]|v31.s[3]|
1825 // +-------+-------+ - - +--------+--------+-----+--------+--------+
1826 //
1827 // Accumulator
1828
1829 // Expand Lhs/Rhs cells to 16 bit.
1830 "uxtl v0.8h, v5.8b\n"
1831 "ld1 {v5.8b}, [%[rhs_ptr]], #8\n"
1832 "uxtl v1.8h, v6.8b\n"
1833 "ld1 {v6.8b}, [%[rhs_ptr]], #8\n"
1834 "uxtl v2.8h, v2.8b\n"
1835 "uxtl v3.8h, v3.8b\n"
1836 "uxtl v4.8h, v4.8b\n"
1837
1838 // Multiply-accumulate, top third
1839 "umlal v8.4s, v2.4h, v0.h[0]\n"
1840 "umlal v9.4s, v2.4h, v0.h[1]\n"
1841 "umlal v10.4s, v2.4h, v0.h[2]\n"
1842 "umlal v11.4s, v2.4h, v0.h[3]\n"
1843 "umlal v12.4s, v2.4h, v1.h[0]\n"
1844 "umlal v13.4s, v2.4h, v1.h[1]\n"
1845 "umlal v14.4s, v2.4h, v1.h[2]\n"
1846 "umlal v15.4s, v2.4h, v1.h[3]\n"
1847 "umlal2 v8.4s, v2.8h, v0.h[4]\n"
1848 "umlal2 v9.4s, v2.8h, v0.h[5]\n"
1849 "umlal2 v10.4s, v2.8h, v0.h[6]\n"
1850 "umlal2 v11.4s, v2.8h, v0.h[7]\n"
1851 "umlal2 v12.4s, v2.8h, v1.h[4]\n"
1852 "umlal2 v13.4s, v2.8h, v1.h[5]\n"
1853 "umlal2 v14.4s, v2.8h, v1.h[6]\n"
1854 "umlal2 v15.4s, v2.8h, v1.h[7]\n"
1855 "ld1 {v2.8b}, [%[lhs_ptr]], #8\n"
1856
1857 // Multiply-accumulate, middle third
1858 "umlal v16.4s, v3.4h, v0.h[0]\n"
1859 "umlal v17.4s, v3.4h, v0.h[1]\n"
1860 "umlal v18.4s, v3.4h, v0.h[2]\n"
1861 "umlal v19.4s, v3.4h, v0.h[3]\n"
1862 "umlal v20.4s, v3.4h, v1.h[0]\n"
1863 "umlal v21.4s, v3.4h, v1.h[1]\n"
1864 "umlal v22.4s, v3.4h, v1.h[2]\n"
1865 "umlal v23.4s, v3.4h, v1.h[3]\n"
1866 "umlal2 v16.4s, v3.8h, v0.h[4]\n"
1867 "umlal2 v17.4s, v3.8h, v0.h[5]\n"
1868 "umlal2 v18.4s, v3.8h, v0.h[6]\n"
1869 "umlal2 v19.4s, v3.8h, v0.h[7]\n"
1870 "umlal2 v20.4s, v3.8h, v1.h[4]\n"
1871 "umlal2 v21.4s, v3.8h, v1.h[5]\n"
1872 "umlal2 v22.4s, v3.8h, v1.h[6]\n"
1873 "umlal2 v23.4s, v3.8h, v1.h[7]\n"
1874 "ld1 {v3.8b}, [%[lhs_ptr]], #8\n"
1875
1876 "subs %w[depth], %w[depth], #2\n"
1877
1878 // Multiply-accumulate, bottom third
1879 "umlal v24.4s, v4.4h, v0.h[0]\n"
1880 "umlal v25.4s, v4.4h, v0.h[1]\n"
1881 "umlal v26.4s, v4.4h, v0.h[2]\n"
1882 "umlal v27.4s, v4.4h, v0.h[3]\n"
1883 "umlal v28.4s, v4.4h, v1.h[0]\n"
1884 "umlal v29.4s, v4.4h, v1.h[1]\n"
1885 "umlal v30.4s, v4.4h, v1.h[2]\n"
1886 "umlal v31.4s, v4.4h, v1.h[3]\n"
1887 "umlal2 v24.4s, v4.8h, v0.h[4]\n"
1888 "umlal2 v25.4s, v4.8h, v0.h[5]\n"
1889 "umlal2 v26.4s, v4.8h, v0.h[6]\n"
1890 "umlal2 v27.4s, v4.8h, v0.h[7]\n"
1891 "umlal2 v28.4s, v4.8h, v1.h[4]\n"
1892 "umlal2 v29.4s, v4.8h, v1.h[5]\n"
1893 "umlal2 v30.4s, v4.8h, v1.h[6]\n"
1894 "umlal2 v31.4s, v4.8h, v1.h[7]\n"
1895 "ld1 {v4.8b}, [%[lhs_ptr]], #8\n"
1896
1897 "bne " GEMMLOWP_LABEL_LOOP "b\n"
1898
1899 GEMMLOWP_LABEL_AFTER_LOOP
1900 ":\n"
1901
1902 // Expand Lhs/Rhs cells to 16 bit.
1903 "uxtl v0.8h, v5.8b\n"
1904 "uxtl v1.8h, v6.8b\n"
1905 "uxtl v2.8h, v2.8b\n"
1906 "uxtl v3.8h, v3.8b\n"
1907 "uxtl v4.8h, v4.8b\n"
1908
1909 // Multiply-accumulate, level of depth 0
1910 "umlal v8.4s, v2.4h, v0.h[0]\n"
1911 "umlal v9.4s, v2.4h, v0.h[1]\n"
1912 "umlal v10.4s, v2.4h, v0.h[2]\n"
1913 "umlal v11.4s, v2.4h, v0.h[3]\n"
1914 "umlal v12.4s, v2.4h, v1.h[0]\n"
1915 "umlal v13.4s, v2.4h, v1.h[1]\n"
1916 "umlal v14.4s, v2.4h, v1.h[2]\n"
1917 "umlal v15.4s, v2.4h, v1.h[3]\n"
1918 "umlal v16.4s, v3.4h, v0.h[0]\n"
1919 "umlal v17.4s, v3.4h, v0.h[1]\n"
1920 "umlal v18.4s, v3.4h, v0.h[2]\n"
1921 "umlal v19.4s, v3.4h, v0.h[3]\n"
1922 "umlal v20.4s, v3.4h, v1.h[0]\n"
1923 "umlal v21.4s, v3.4h, v1.h[1]\n"
1924 "umlal v22.4s, v3.4h, v1.h[2]\n"
1925 "umlal v23.4s, v3.4h, v1.h[3]\n"
1926 "umlal v24.4s, v4.4h, v0.h[0]\n"
1927 "umlal v25.4s, v4.4h, v0.h[1]\n"
1928 "umlal v26.4s, v4.4h, v0.h[2]\n"
1929 "umlal v27.4s, v4.4h, v0.h[3]\n"
1930 "umlal v28.4s, v4.4h, v1.h[0]\n"
1931 "umlal v29.4s, v4.4h, v1.h[1]\n"
1932 "umlal v30.4s, v4.4h, v1.h[2]\n"
1933 "umlal v31.4s, v4.4h, v1.h[3]\n"
1934
1935 // Multiply-accumulate, level of depth 1
1936 "umlal2 v8.4s, v2.8h, v0.h[4]\n"
1937 "umlal2 v9.4s, v2.8h, v0.h[5]\n"
1938 "umlal2 v10.4s, v2.8h, v0.h[6]\n"
1939 "umlal2 v11.4s, v2.8h, v0.h[7]\n"
1940 "umlal2 v12.4s, v2.8h, v1.h[4]\n"
1941 "umlal2 v13.4s, v2.8h, v1.h[5]\n"
1942 "umlal2 v14.4s, v2.8h, v1.h[6]\n"
1943 "umlal2 v15.4s, v2.8h, v1.h[7]\n"
1944 "umlal2 v16.4s, v3.8h, v0.h[4]\n"
1945 "umlal2 v17.4s, v3.8h, v0.h[5]\n"
1946 "umlal2 v18.4s, v3.8h, v0.h[6]\n"
1947 "umlal2 v19.4s, v3.8h, v0.h[7]\n"
1948 "umlal2 v20.4s, v3.8h, v1.h[4]\n"
1949 "umlal2 v21.4s, v3.8h, v1.h[5]\n"
1950 "umlal2 v22.4s, v3.8h, v1.h[6]\n"
1951 "umlal2 v23.4s, v3.8h, v1.h[7]\n"
1952 "umlal2 v24.4s, v4.8h, v0.h[4]\n"
1953 "umlal2 v25.4s, v4.8h, v0.h[5]\n"
1954 "umlal2 v26.4s, v4.8h, v0.h[6]\n"
1955 "umlal2 v27.4s, v4.8h, v0.h[7]\n"
1956 "umlal2 v28.4s, v4.8h, v1.h[4]\n"
1957 "umlal2 v29.4s, v4.8h, v1.h[5]\n"
1958 "umlal2 v30.4s, v4.8h, v1.h[6]\n"
1959 "umlal2 v31.4s, v4.8h, v1.h[7]\n"
1960
1961 // Store accumulators
1962 "mov x0, %[accum_ptr]\n"
1963 "st1 {v8.16b}, [x0], #16\n"
1964 "st1 {v16.16b}, [x0], #16\n"
1965 "st1 {v24.16b}, [x0], #16\n"
1966 "st1 {v9.16b}, [x0], #16\n"
1967 "st1 {v17.16b}, [x0], #16\n"
1968 "st1 {v25.16b}, [x0], #16\n"
1969 "st1 {v10.16b}, [x0], #16\n"
1970 "st1 {v18.16b}, [x0], #16\n"
1971 "st1 {v26.16b}, [x0], #16\n"
1972 "st1 {v11.16b}, [x0], #16\n"
1973 "st1 {v19.16b}, [x0], #16\n"
1974 "st1 {v27.16b}, [x0], #16\n"
1975 "st1 {v12.16b}, [x0], #16\n"
1976 "st1 {v20.16b}, [x0], #16\n"
1977 "st1 {v28.16b}, [x0], #16\n"
1978 "st1 {v13.16b}, [x0], #16\n"
1979 "st1 {v21.16b}, [x0], #16\n"
1980 "st1 {v29.16b}, [x0], #16\n"
1981 "st1 {v14.16b}, [x0], #16\n"
1982 "st1 {v22.16b}, [x0], #16\n"
1983 "st1 {v30.16b}, [x0], #16\n"
1984 "st1 {v15.16b}, [x0], #16\n"
1985 "st1 {v23.16b}, [x0], #16\n"
1986 "st1 {v31.16b}, [x0], #16\n"
1987 : // outputs
1988 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1989 [depth] "+r"(depth)
1990 : // inputs
1991 [accum_ptr] "r"(accum_ptr)
1992 : // clobbers
1993 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
1994 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
1995 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
1996 "v28", "v29", "v30", "v31");
1997 }
1998 };
1999
2000 // Faster kernel by ARM. Not expanding operands before multiplication.
2001 // Tuned for A57. Compare to
2002 // NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand
2003 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57 {
2004 typedef std::uint8_t OperandType;
2005 typedef std::uint32_t AccumulatorType;
2006 typedef KernelFormat<
2007 KernelSideFormat<CellFormat<5, 16, CellOrder::WidthMajor>, 1>,
2008 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
2009 Format;
RunNEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A572010 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2011 AccumulatorType* accum_ptr, int depth) {
2012 static const int kLhsWidth = Format::Lhs::kWidth;
2013 static const int kRhsWidth = Format::Rhs::kWidth;
2014 AccumulatorType rowmajor_accumulator_buffer[kLhsWidth * kRhsWidth];
2015 asm volatile(
2016 // Clear aggregators
2017 "dup v12.4s, wzr\n"
2018 "dup v13.4s, wzr\n"
2019 "dup v14.4s, wzr\n"
2020 "dup v15.4s, wzr\n"
2021 "dup v16.4s, wzr\n"
2022 "dup v17.4s, wzr\n"
2023 "dup v18.4s, wzr\n"
2024 "dup v19.4s, wzr\n"
2025 "dup v20.4s, wzr\n"
2026 "dup v21.4s, wzr\n"
2027 "dup v22.4s, wzr\n"
2028 "dup v23.4s, wzr\n"
2029 "dup v24.4s, wzr\n"
2030 "dup v25.4s, wzr\n"
2031 "dup v26.4s, wzr\n"
2032 "dup v27.4s, wzr\n"
2033 "dup v28.4s, wzr\n"
2034 "dup v29.4s, wzr\n"
2035 "dup v30.4s, wzr\n"
2036 "dup v31.4s, wzr\n"
2037
2038 GEMMLOWP_LABEL_LOOP
2039 ":\n"
2040
2041 // Overview of register layout:
2042 //
2043 // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
2044 // A 5x16 block of Lhs is cycled through v4 and v5 in 8 bit.
2045 //
2046 // A 4x5 block of aggregators is stored in v12-v31 (as 4x32 bit
2047 // components which would need to be added at the end)
2048 //
2049 // The Lhs vectors are multiplied by the Rhs vectors with a widening
2050 // multiply to produce an intermediate result which is stored in
2051 // v6-v11. Each intermediate result is 8x16 bits so this happens
2052 // twice for each Lhs/Rhs combination (once with UMULL for elements
2053 // 0-7 and once with UMULL2 for elements 8-15).
2054 //
2055 // UADALP is used to accumulate these intermediate results into the
2056 // result aggregators.
2057 //
2058 //
2059 //
2060 // +--------+--------+--------+--------+
2061 // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
2062 // Rhs +--------+--------+--------+--------+
2063 // | ... | ... | ... | ... |
2064 // +--------+--------+--------+--------|
2065 // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
2066 // +--------+--------+--------+--------+
2067 //
2068 // | | | | |
2069 //
2070 // Lhs | | | | |
2071 //
2072 // +-------+-----+--------+ - - +--------+--------+--------+--------+
2073 // |v4.b[0]| ... |v4.b[15]| | v12.4s | v13.4s | v14.4s | v15.4s |
2074 // |v5.b[0]| ... |v5.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s |
2075 // |v4.b[0]| ... |v4.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s |
2076 // |v5.b[0]| ... |v5.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s |
2077 // |v4.b[0]| ... |v4.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s |
2078 // +-------+--------------+ - - +--------+--------+--------+--------+
2079 //
2080 // Accumulator
2081 //
2082 //
2083 // Further possible optimisations (not tried):
2084 // - Move early loads into previous iteration (see Float32_WithScalar
2085 // for example). - Unroll loop 2x to alternate more smoothly between
2086 // v4 and v5. - A different number of temporary registers might work
2087 // better. - Pairing umull with corresponding umull2 might allow
2088 // better
2089 // register loading (e.g. at the start of the loop)
2090 // - Interleaving umull{2} and uadalp even more aggressively might
2091 // help, (not sure about latency vs. dispatch rate).
2092 //
2093 //
2094 // Start loading Rhs - further loads are interleaved amongst the
2095 // multiplies for better dispatch on A57.
2096 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2097
2098 // Load first Lhs vector - further loads are interleaved amongst the
2099 // multiplies
2100 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
2101
2102 "umull v6.8h, v0.8b, v4.8b\n"
2103 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // 2nd RHS element
2104 "umull v7.8h, v1.8b, v4.8b\n"
2105 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" // 3rd RHS element
2106 "umull v8.8h, v2.8b, v4.8b\n"
2107 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" // 4th RHS element
2108 "umull v9.8h, v3.8b, v4.8b\n"
2109 "umull2 v10.8h, v0.16b, v4.16b\n"
2110 "umull2 v11.8h, v1.16b, v4.16b\n"
2111 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" // 2nd LHS element
2112
2113 "uadalp v12.4s, v6.8h\n"
2114 "umull2 v6.8h, v2.16b, v4.16b\n"
2115 "uadalp v13.4s, v7.8h\n"
2116 "umull2 v7.8h, v3.16b, v4.16b\n"
2117 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // 1st LHS element done - Reuse v4
2118 // for 3rd LHS element
2119 "uadalp v14.4s, v8.8h\n"
2120 "umull v8.8h, v0.8b, v5.8b\n"
2121 "uadalp v15.4s, v9.8h\n"
2122 "umull v9.8h, v1.8b, v5.8b\n"
2123 "uadalp v12.4s, v10.8h\n"
2124 "umull v10.8h, v2.8b, v5.8b\n"
2125 "uadalp v13.4s, v11.8h\n"
2126 "umull v11.8h, v3.8b, v5.8b\n"
2127
2128 "uadalp v14.4s, v6.8h\n"
2129 "umull2 v6.8h, v0.16b, v5.16b\n"
2130 "uadalp v15.4s, v7.8h\n"
2131 "umull2 v7.8h, v1.16b, v5.16b\n"
2132 "uadalp v16.4s, v8.8h\n"
2133 "umull2 v8.8h, v2.16b, v5.16b\n"
2134 "uadalp v17.4s, v9.8h\n"
2135 "umull2 v9.8h, v3.16b, v5.16b\n"
2136 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" // 2nd LHS element done - Reuse v5
2137 // for 4th LHS element
2138 "uadalp v18.4s, v10.8h\n"
2139 "umull v10.8h, v0.8b, v4.8b\n"
2140 "uadalp v19.4s, v11.8h\n"
2141 "umull v11.8h, v1.8b, v4.8b\n"
2142
2143 "uadalp v16.4s, v6.8h\n"
2144 "umull v6.8h, v2.8b, v4.8b\n"
2145 "uadalp v17.4s, v7.8h\n"
2146 "umull v7.8h, v3.8b, v4.8b\n"
2147 "uadalp v18.4s, v8.8h\n"
2148 "umull2 v8.8h, v0.16b, v4.16b\n"
2149 "uadalp v19.4s, v9.8h\n"
2150 "umull2 v9.8h, v1.16b, v4.16b\n"
2151 "uadalp v20.4s, v10.8h\n"
2152 "umull2 v10.8h, v2.16b, v4.16b\n"
2153 "uadalp v21.4s, v11.8h\n"
2154 "umull2 v11.8h, v3.16b, v4.16b\n"
2155 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // 3rd LHS element done - Reuse v4
2156 // for 5th LHS element
2157
2158 "uadalp v22.4s, v6.8h\n"
2159 "umull v6.8h, v0.8b, v5.8b\n"
2160 "uadalp v23.4s, v7.8h\n"
2161 "umull v7.8h, v1.8b, v5.8b\n"
2162 "uadalp v20.4s, v8.8h\n"
2163 "umull v8.8h, v2.8b, v5.8b\n"
2164 "uadalp v21.4s, v9.8h\n"
2165 "umull v9.8h, v3.8b, v5.8b\n"
2166 "uadalp v22.4s, v10.8h\n"
2167 "umull2 v10.8h, v0.16b, v5.16b\n"
2168 "uadalp v23.4s, v11.8h\n"
2169 "umull2 v11.8h, v1.16b, v5.16b\n"
2170
2171 "uadalp v24.4s, v6.8h\n"
2172 "umull2 v6.8h, v2.16b, v5.16b\n"
2173 "uadalp v25.4s, v7.8h\n"
2174 "umull2 v7.8h, v3.16b, v5.16b\n"
2175 "uadalp v26.4s, v8.8h\n"
2176 "umull v8.8h, v0.8b, v4.8b\n"
2177 "uadalp v27.4s, v9.8h\n"
2178 "umull v9.8h, v1.8b, v4.8b\n"
2179 "uadalp v24.4s, v10.8h\n"
2180 "umull v10.8h, v2.8b, v4.8b\n"
2181 "uadalp v25.4s, v11.8h\n"
2182 "umull v11.8h, v3.8b, v4.8b\n"
2183
2184 "uadalp v26.4s, v6.8h\n"
2185 "umull2 v6.8h, v0.16b, v4.16b\n"
2186 "uadalp v27.4s, v7.8h\n"
2187 "umull2 v7.8h, v1.16b, v4.16b\n"
2188 "uadalp v28.4s, v8.8h\n"
2189 "umull2 v8.8h, v2.16b, v4.16b\n"
2190 "uadalp v29.4s, v9.8h\n"
2191 "umull2 v9.8h, v3.16b, v4.16b\n"
2192 "uadalp v30.4s, v10.8h\n"
2193 "uadalp v31.4s, v11.8h\n"
2194
2195 "uadalp v28.4s, v6.8h\n"
2196 "uadalp v29.4s, v7.8h\n"
2197 // Loop. Decrement loop index (depth) by 16, since we just handled
2198 // 16 levels of depth. Do this subs a bit before the end of the loop
2199 // for better dispatch on A57.
2200 "subs %w[depth], %w[depth], #16\n"
2201 "uadalp v30.4s, v8.8h\n"
2202 "uadalp v31.4s, v9.8h\n"
2203
2204 "bne " GEMMLOWP_LABEL_LOOP
2205 "b\n"
2206
2207 // Reduce aggregators horizontally
2208 "addp v0.4s, v12.4s, v13.4s\n"
2209 "addp v1.4s, v14.4s, v15.4s\n"
2210 "addp v2.4s, v16.4s, v17.4s\n"
2211 "addp v3.4s, v18.4s, v19.4s\n"
2212 "addp v4.4s, v20.4s, v21.4s\n"
2213 "addp v5.4s, v22.4s, v23.4s\n"
2214 "addp v6.4s, v24.4s, v25.4s\n"
2215 "addp v7.4s, v26.4s, v27.4s\n"
2216 "addp v8.4s, v28.4s, v29.4s\n"
2217 "addp v9.4s, v30.4s, v31.4s\n"
2218
2219 "addp v10.4s, v0.4s, v1.4s\n"
2220 "addp v11.4s, v2.4s, v3.4s\n"
2221 "addp v12.4s, v4.4s, v5.4s\n"
2222 "addp v13.4s, v6.4s, v7.4s\n"
2223 "addp v14.4s, v8.4s, v9.4s\n"
2224
2225 "mov x0, %[rowmajor_accumulator_buffer]\n"
2226 "st1 {v10.16b}, [x0], #16\n"
2227 "st1 {v11.16b}, [x0], #16\n"
2228 "st1 {v12.16b}, [x0], #16\n"
2229 "st1 {v13.16b}, [x0], #16\n"
2230 "st1 {v14.16b}, [x0], #16\n"
2231 : // outputs
2232 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2233 [depth] "+r"(depth)
2234 : // inputs
2235 [rowmajor_accumulator_buffer] "r"(rowmajor_accumulator_buffer)
2236 : // clobbers
2237 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
2238 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
2239 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
2240 "v28", "v29", "v30", "v31");
2241
2242 // accumulate row-major accumulators into global (column-major) accumulators
2243 for (int l = 0; l < kLhsWidth; l++) {
2244 for (int r = 0; r < kRhsWidth; r++) {
2245 accum_ptr[l + kLhsWidth * r] +=
2246 rowmajor_accumulator_buffer[r + l * kRhsWidth];
2247 }
2248 }
2249 }
2250 };
2251
2252 // Fast kernel operating on int8 operands.
2253 // It is assumed that one of the two int8 operands only takes values
2254 // in [-127, 127], while the other may freely range in [-128, 127].
2255 // The issue with both operands taking the value -128 is that:
2256 // -128*-128 + -128*-128 == -32768 overflows int16.
2257 // Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
2258 // range. That is the basic idea of this kernel.
2259 struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
2260 typedef std::int8_t OperandType;
2261 typedef std::int32_t AccumulatorType;
2262 typedef KernelFormat<
2263 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
2264 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
2265 Format;
RunNEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits2266 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2267 AccumulatorType* accum_ptr, int depth) {
2268 std::size_t start_depth = 123;
2269 std::size_t run_depth = depth;
2270 std::size_t dst_col_stride = 4;
2271 AccumulatorType* dst_ptr = accum_ptr;
2272 asm volatile(
2273 // Overview of register layout:
2274 //
2275 // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
2276 // A 4x16 block of Lhs is stored in 8 bit in v4--v7.
2277 //
2278 // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit
2279 // components which need to be horizontally-added at the end)
2280 //
2281 // The Lhs vectors are multiplied by the Rhs vectors with a widening
2282 // multiply over the 8 first levels of depth, producing int16x8
2283 // vectors of products for each position in the accumulator matrix.
2284 // Here comes the special trick: since the operands are signed int8,
2285 // their range being [ -2^7 , 2^7 ), their products are in range
2286 // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values
2287 // without any risk of overflowing int16.
2288 // We thus proceed with the 8 next levels of depth, multiplying
2289 // again Lhs by Rhs, accumulating into this existing int16x8 vector.
2290 //
2291 // Only then, having processed 16 levels of depth, do we need to
2292 // horizontally add these int16x8 accumulators into the final
2293 // int32x4 accumulators.
2294 //
2295 // As we do not have enough registers to store all 16 int16x8
2296 // temporary-16bit-accumulators, we have them cycle through v8--v15.
2297 //
2298 //
2299 // Register layout (ignoring the v8--v15 temporary 16bit accumulators):
2300 //
2301 // +--------+--------+--------+--------+
2302 // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
2303 // Rhs +--------+--------+--------+--------+
2304 // | ... | ... | ... | ... |
2305 // +--------+--------+--------+--------|
2306 // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
2307 // +--------+--------+--------+--------+
2308 //
2309 // | | | | |
2310 //
2311 // Lhs | | | | |
2312 //
2313 // +-------+-----+--------+ - - +--------+--------+--------+--------+
2314 // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s |
2315 // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s |
2316 // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s |
2317 // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s |
2318 // +-------+--------------+ - - +--------+--------+--------+--------+
2319 //
2320 // Accumulator
2321 //
2322
2323 // Clear accumulators
2324 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2325 "dup v16.4s, wzr\n"
2326 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
2327 "dup v17.4s, wzr\n"
2328 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
2329 "dup v18.4s, wzr\n"
2330 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
2331 "dup v19.4s, wzr\n"
2332 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
2333 "dup v20.4s, wzr\n"
2334 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
2335 "dup v21.4s, wzr\n"
2336 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
2337 "dup v22.4s, wzr\n"
2338 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
2339 "dup v23.4s, wzr\n"
2340 "subs %[run_depth], %[run_depth], #16\n"
2341 "dup v24.4s, wzr\n"
2342 "mov x0, %[dst_ptr]\n"
2343 "dup v25.4s, wzr\n"
2344 "dup v26.4s, wzr\n"
2345 "dup v27.4s, wzr\n"
2346 "dup v28.4s, wzr\n"
2347 "dup v29.4s, wzr\n"
2348 "dup v30.4s, wzr\n"
2349 "dup v31.4s, wzr\n"
2350
2351 "smull v12.8h, v0.8b, v4.8b\n"
2352 "smull v13.8h, v1.8b, v4.8b\n"
2353 "smull v14.8h, v0.8b, v5.8b\n"
2354 "smull v15.8h, v1.8b, v5.8b\n"
2355 "smlal2 v12.8h, v0.16b, v4.16b\n"
2356 "smlal2 v13.8h, v1.16b, v4.16b\n"
2357 "smlal2 v14.8h, v0.16b, v5.16b\n"
2358 "smlal2 v15.8h, v1.16b, v5.16b\n"
2359
2360 "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
2361
2362 GEMMLOWP_LABEL_LOOP
2363 ":\n"
2364
2365 "subs %[run_depth], %[run_depth], #16\n"
2366
2367 "sadalp v16.4s, v12.8h\n"
2368 "smull v12.8h, v0.8b, v6.8b\n"
2369 "sadalp v17.4s, v13.8h\n"
2370 "smull v13.8h, v0.8b, v7.8b\n"
2371 "sadalp v20.4s, v14.8h\n"
2372 "smull v14.8h, v1.8b, v6.8b\n"
2373 "sadalp v21.4s, v15.8h\n"
2374 "smull v15.8h, v1.8b, v7.8b\n"
2375 "smlal2 v12.8h, v0.16b, v6.16b\n"
2376 "smlal2 v13.8h, v0.16b, v7.16b\n"
2377 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2378 "smlal2 v14.8h, v1.16b, v6.16b\n"
2379 "smlal2 v15.8h, v1.16b, v7.16b\n"
2380 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
2381 "sadalp v24.4s, v12.8h\n"
2382 "smull v12.8h, v2.8b, v4.8b\n"
2383 "sadalp v28.4s, v13.8h\n"
2384 "smull v13.8h, v3.8b, v4.8b\n"
2385 "sadalp v25.4s, v14.8h\n"
2386 "smull v14.8h, v2.8b, v5.8b\n"
2387 "sadalp v29.4s, v15.8h\n"
2388 "smull v15.8h, v3.8b, v5.8b\n"
2389 "smlal2 v12.8h, v2.16b, v4.16b\n"
2390 "smlal2 v13.8h, v3.16b, v4.16b\n"
2391 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
2392 "smlal2 v14.8h, v2.16b, v5.16b\n"
2393 "smlal2 v15.8h, v3.16b, v5.16b\n"
2394 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
2395 "sadalp v18.4s, v12.8h\n"
2396 "smull v12.8h, v2.8b, v6.8b\n"
2397 "sadalp v19.4s, v13.8h\n"
2398 "smull v13.8h, v2.8b, v7.8b\n"
2399 "sadalp v22.4s, v14.8h\n"
2400 "smull v14.8h, v3.8b, v6.8b\n"
2401 "sadalp v23.4s, v15.8h\n"
2402 "smull v15.8h, v3.8b, v7.8b\n"
2403 "smlal2 v12.8h, v2.16b, v6.16b\n"
2404 "smlal2 v13.8h, v2.16b, v7.16b\n"
2405 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
2406 "smlal2 v14.8h, v3.16b, v6.16b\n"
2407 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
2408 "smlal2 v15.8h, v3.16b, v7.16b\n"
2409 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
2410 "sadalp v26.4s, v12.8h\n"
2411 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
2412 "smull v12.8h, v0.8b, v4.8b\n"
2413 "sadalp v30.4s, v13.8h\n"
2414 "smull v13.8h, v1.8b, v4.8b\n"
2415 "sadalp v27.4s, v14.8h\n"
2416 "smull v14.8h, v0.8b, v5.8b\n"
2417 "sadalp v31.4s, v15.8h\n"
2418 "smull v15.8h, v1.8b, v5.8b\n"
2419 "smlal2 v12.8h, v0.16b, v4.16b\n"
2420 "smlal2 v13.8h, v1.16b, v4.16b\n"
2421 "smlal2 v14.8h, v0.16b, v5.16b\n"
2422 "smlal2 v15.8h, v1.16b, v5.16b\n"
2423
2424 "bne " GEMMLOWP_LABEL_LOOP "b\n"
2425
2426 GEMMLOWP_LABEL_AFTER_LOOP
2427 ":\n"
2428
2429 // Load accumulators from memory
2430 "ld1 {v8.16b}, [x0], #16\n"
2431 "ld1 {v9.16b}, [x0], #16\n"
2432 "ld1 {v10.16b}, [x0], #16\n"
2433 "ld1 {v11.16b}, [x0], #16\n"
2434 "mov x0, %[dst_ptr]\n"
2435
2436 // Do the remaining arithmetic for the 16 last levels of depths.
2437 // All the operands are already loaded.
2438 "sadalp v16.4s, v12.8h\n"
2439 "smull v12.8h, v0.8b, v6.8b\n"
2440 "sadalp v17.4s, v13.8h\n"
2441 "smull v13.8h, v0.8b, v7.8b\n"
2442 "sadalp v20.4s, v14.8h\n"
2443 "smull v14.8h, v1.8b, v6.8b\n"
2444 "sadalp v21.4s, v15.8h\n"
2445 "smull v15.8h, v1.8b, v7.8b\n"
2446 "smlal2 v12.8h, v0.16b, v6.16b\n"
2447 "smlal2 v13.8h, v0.16b, v7.16b\n"
2448 "smlal2 v14.8h, v1.16b, v6.16b\n"
2449 "smlal2 v15.8h, v1.16b, v7.16b\n"
2450 "sadalp v24.4s, v12.8h\n"
2451 "smull v12.8h, v2.8b, v4.8b\n"
2452 "sadalp v28.4s, v13.8h\n"
2453 "smull v13.8h, v3.8b, v4.8b\n"
2454 "sadalp v25.4s, v14.8h\n"
2455 "smull v14.8h, v2.8b, v5.8b\n"
2456 "sadalp v29.4s, v15.8h\n"
2457 "smull v15.8h, v3.8b, v5.8b\n"
2458 "smlal2 v12.8h, v2.16b, v4.16b\n"
2459 "smlal2 v13.8h, v3.16b, v4.16b\n"
2460 "smlal2 v14.8h, v2.16b, v5.16b\n"
2461 "smlal2 v15.8h, v3.16b, v5.16b\n"
2462 "sadalp v18.4s, v12.8h\n"
2463 "smull v12.8h, v2.8b, v6.8b\n"
2464 "sadalp v19.4s, v13.8h\n"
2465 "smull v13.8h, v2.8b, v7.8b\n"
2466 "sadalp v22.4s, v14.8h\n"
2467 "smull v14.8h, v3.8b, v6.8b\n"
2468 "sadalp v23.4s, v15.8h\n"
2469 "smull v15.8h, v3.8b, v7.8b\n"
2470 "smlal2 v12.8h, v2.16b, v6.16b\n"
2471 "smlal2 v13.8h, v2.16b, v7.16b\n"
2472 "smlal2 v14.8h, v3.16b, v6.16b\n"
2473 "smlal2 v15.8h, v3.16b, v7.16b\n"
2474 "sadalp v26.4s, v12.8h\n"
2475 "sadalp v30.4s, v13.8h\n"
2476 "sadalp v27.4s, v14.8h\n"
2477 "sadalp v31.4s, v15.8h\n"
2478
2479 // Reduce aggregators horizontally
2480 "addp v0.4s, v16.4s, v20.4s\n"
2481 "addp v1.4s, v17.4s, v21.4s\n"
2482 "addp v2.4s, v18.4s, v22.4s\n"
2483 "addp v3.4s, v19.4s, v23.4s\n"
2484 "addp v4.4s, v24.4s, v28.4s\n"
2485 "addp v5.4s, v25.4s, v29.4s\n"
2486 "addp v6.4s, v26.4s, v30.4s\n"
2487 "addp v7.4s, v27.4s, v31.4s\n"
2488
2489 "addp v12.4s, v0.4s, v4.4s\n"
2490 "addp v13.4s, v1.4s, v5.4s\n"
2491 "addp v14.4s, v2.4s, v6.4s\n"
2492 "addp v15.4s, v3.4s, v7.4s\n"
2493
2494 // Add to the accumulators loaded from memory
2495 "add v8.4s, v8.4s, v12.4s\n"
2496 "add v9.4s, v9.4s, v13.4s\n"
2497 "add v10.4s, v10.4s, v14.4s\n"
2498 "add v11.4s, v11.4s, v15.4s\n"
2499
2500 // Store accumulators back to memory
2501 "st1 {v8.16b}, [x0], #16\n"
2502 "st1 {v9.16b}, [x0], #16\n"
2503 "st1 {v10.16b}, [x0], #16\n"
2504 "st1 {v11.16b}, [x0], #16\n"
2505 : // outputs
2506 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2507 [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth),
2508 [dst_col_stride] "+r"(dst_col_stride)
2509 : // inputs
2510 [start_depth] "r"(start_depth)
2511 : // clobbers
2512 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
2513 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
2514 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
2515 "v28", "v29", "v30", "v31");
2516 }
2517 };
2518
2519 #ifdef __ARM_FEATURE_DOTPROD
2520 // Kernels utilizing the Armv8.2 Dot Product extension.
2521 //
2522 // The dot product instructions work by taking 4 consecutive 8-bit depth
2523 // values from each operand, multiplying the 4 pairs together and
2524 // accumulating all the results into the corresponding 32-bit accumulator
2525 // lane. As such, the operation is identical to a 32-bit instruction (like
2526 // FMLA used in SGEMM), except that 4 depth values are processed at a time
2527 // instead of 1.
2528
2529 // Thus, this first kernel is a carbon copy of
2530 // "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
2531 // performance for most processors) below with the opcode (fmla -> udot) and
2532 // types (float32 -> uint8/uint32) changed.
2533 //
2534 // A signed version of this kernel could be produced by replacing "udot"
2535 // with "sdot" - performance should be identical to this udot kernel.
2536 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct {
2537 typedef std::uint8_t OperandType;
2538 typedef std::uint32_t AccumulatorType;
2539 typedef KernelFormat<
2540 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
2541 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
2542 Format;
RunNEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct2543 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2544 AccumulatorType* accum_ptr, int depth) {
2545 asm volatile(
2546 // Load accumulators
2547 "mov x0, %[accum_ptr]\n"
2548 "ld1 {v8.4s}, [x0], #16\n"
2549 "ld1 {v16.4s}, [x0], #16\n"
2550 "ld1 {v24.4s}, [x0], #16\n"
2551 "ld1 {v9.4s}, [x0], #16\n"
2552 "ld1 {v17.4s}, [x0], #16\n"
2553 "ld1 {v25.4s}, [x0], #16\n"
2554 "ld1 {v10.4s}, [x0], #16\n"
2555 "ld1 {v18.4s}, [x0], #16\n"
2556 "ld1 {v26.4s}, [x0], #16\n"
2557 "ld1 {v11.4s}, [x0], #16\n"
2558 "ld1 {v19.4s}, [x0], #16\n"
2559 "ld1 {v27.4s}, [x0], #16\n"
2560 "ld1 {v12.4s}, [x0], #16\n"
2561 "ld1 {v20.4s}, [x0], #16\n"
2562 "ld1 {v28.4s}, [x0], #16\n"
2563 "ld1 {v13.4s}, [x0], #16\n"
2564 "ld1 {v21.4s}, [x0], #16\n"
2565 "ld1 {v29.4s}, [x0], #16\n"
2566 "ld1 {v14.4s}, [x0], #16\n"
2567 "ld1 {v22.4s}, [x0], #16\n"
2568 "ld1 {v30.4s}, [x0], #16\n"
2569 "ld1 {v15.4s}, [x0], #16\n"
2570 "ld1 {v23.4s}, [x0], #16\n"
2571 "ld1 {v31.4s}, [x0], #16\n"
2572
2573 // The start of the loop assumes first Rhs cell is already loaded, so
2574 // do it here for first iteration.
2575 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
2576
2577 // And the same for the first Lhs cell.
2578 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
2579
2580 GEMMLOWP_LABEL_LOOP
2581 ":\n"
2582
2583 // Start the MACs at the head of the loop - 1st cell from each side
2584 // already loaded.
2585 "udot v8.4s, v2.16b, v0.b[0]\n"
2586 "udot v9.4s, v2.16b, v0.b[1]\n"
2587 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
2588 "udot v10.4s, v2.16b, v0.b[2]\n"
2589 "udot v11.4s, v2.16b, v0.b[3]\n"
2590 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
2591 "udot v12.4s, v2.16b, v1.b[0]\n"
2592 "udot v13.4s, v2.16b, v1.b[1]\n"
2593 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
2594 "udot v14.4s, v2.16b, v1.b[2]\n"
2595 "udot v15.4s, v2.16b, v1.b[3]\n"
2596 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
2597 // for the next iteration early.
2598 "udot v16.4s, v3.16b, v0.b[0]\n"
2599 "udot v17.4s, v3.16b, v0.b[1]\n"
2600 "udot v18.4s, v3.16b, v0.b[2]\n"
2601 "udot v19.4s, v3.16b, v0.b[3]\n"
2602 "udot v20.4s, v3.16b, v1.b[0]\n"
2603 "udot v21.4s, v3.16b, v1.b[1]\n"
2604 "udot v22.4s, v3.16b, v1.b[2]\n"
2605 "udot v23.4s, v3.16b, v1.b[3]\n"
2606 "udot v24.4s, v4.16b, v0.b[0]\n"
2607 "udot v25.4s, v4.16b, v0.b[1]\n"
2608 "udot v26.4s, v4.16b, v0.b[2]\n"
2609 "udot v27.4s, v4.16b, v0.b[3]\n"
2610 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
2611 // load for the next iteration early.
2612 "udot v28.4s, v4.16b, v1.b[0]\n"
2613 "udot v29.4s, v4.16b, v1.b[1]\n"
2614
2615 // Loop. Decrement loop index (depth) by 4 as udot processes 4
2616 // depth values.
2617 "subs %w[depth], %w[depth], #4\n"
2618 "udot v30.4s, v4.16b, v1.b[2]\n"
2619 "udot v31.4s, v4.16b, v1.b[3]\n"
2620
2621 "bne " GEMMLOWP_LABEL_LOOP
2622 "b\n"
2623
2624 // Store accumulators
2625 "mov x0, %[accum_ptr]\n"
2626 "st1 {v8.16b}, [x0], #16\n"
2627 "st1 {v16.16b}, [x0], #16\n"
2628 "st1 {v24.16b}, [x0], #16\n"
2629 "st1 {v9.16b}, [x0], #16\n"
2630 "st1 {v17.16b}, [x0], #16\n"
2631 "st1 {v25.16b}, [x0], #16\n"
2632 "st1 {v10.16b}, [x0], #16\n"
2633 "st1 {v18.16b}, [x0], #16\n"
2634 "st1 {v26.16b}, [x0], #16\n"
2635 "st1 {v11.16b}, [x0], #16\n"
2636 "st1 {v19.16b}, [x0], #16\n"
2637 "st1 {v27.16b}, [x0], #16\n"
2638 "st1 {v12.16b}, [x0], #16\n"
2639 "st1 {v20.16b}, [x0], #16\n"
2640 "st1 {v28.16b}, [x0], #16\n"
2641 "st1 {v13.16b}, [x0], #16\n"
2642 "st1 {v21.16b}, [x0], #16\n"
2643 "st1 {v29.16b}, [x0], #16\n"
2644 "st1 {v14.16b}, [x0], #16\n"
2645 "st1 {v22.16b}, [x0], #16\n"
2646 "st1 {v30.16b}, [x0], #16\n"
2647 "st1 {v15.16b}, [x0], #16\n"
2648 "st1 {v23.16b}, [x0], #16\n"
2649 "st1 {v31.16b}, [x0], #16\n"
2650 : // outputs
2651 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2652 [depth] "+r"(depth)
2653 : // inputs
2654 [accum_ptr] "r"(accum_ptr)
2655 : // clobbers
2656 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
2657 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
2658 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
2659 "v28", "v29", "v30", "v31");
2660 }
2661 };
2662
2663 // As above, except tuned for Cortex-A55r1.
2664 //
2665 // Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1
2666 // with the names changed.
2667 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 {
2668 typedef std::uint8_t OperandType;
2669 typedef std::uint32_t AccumulatorType;
2670 typedef KernelFormat<
2671 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
2672 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
2673 Format;
RunNEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r12674 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2675 AccumulatorType* accum_ptr, int depth) {
2676 asm volatile(
2677 // Load accumulators
2678 "mov x0, %[accum_ptr]\n"
2679 "ld1 {v8.4s}, [x0], #16\n"
2680 "ld1 {v16.4s}, [x0], #16\n"
2681 "ld1 {v24.4s}, [x0], #16\n"
2682 "ld1 {v9.4s}, [x0], #16\n"
2683 "ld1 {v17.4s}, [x0], #16\n"
2684 "ld1 {v25.4s}, [x0], #16\n"
2685 "ld1 {v10.4s}, [x0], #16\n"
2686 "ld1 {v18.4s}, [x0], #16\n"
2687 "ld1 {v26.4s}, [x0], #16\n"
2688 "ld1 {v11.4s}, [x0], #16\n"
2689 "ld1 {v19.4s}, [x0], #16\n"
2690 "ld1 {v27.4s}, [x0], #16\n"
2691 "ld1 {v12.4s}, [x0], #16\n"
2692 "ld1 {v20.4s}, [x0], #16\n"
2693 "ld1 {v28.4s}, [x0], #16\n"
2694 "ld1 {v13.4s}, [x0], #16\n"
2695 "ld1 {v21.4s}, [x0], #16\n"
2696 "ld1 {v29.4s}, [x0], #16\n"
2697 "ld1 {v14.4s}, [x0], #16\n"
2698 "ld1 {v22.4s}, [x0], #16\n"
2699 "ld1 {v30.4s}, [x0], #16\n"
2700 "ld1 {v15.4s}, [x0], #16\n"
2701 "ld1 {v23.4s}, [x0], #16\n"
2702 "ld1 {v31.4s}, [x0], #16\n"
2703
2704 // For details on how this kernel works, see the Float32 kernel below.
2705
2706 "ldr d0, [%[rhs_ptr]]\n"
2707 "ldr x18, [%[rhs_ptr], #8]\n"
2708
2709 "ldr q2, [%[lhs_ptr]]\n"
2710 "ldr q3, [%[lhs_ptr], #16]\n"
2711
2712 GEMMLOWP_LABEL_LOOP
2713 ":\n"
2714
2715 "udot v8.4s, v2.16b, v0.b[0]\n"
2716 "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
2717 "udot v9.4s, v2.16b, v0.b[1]\n"
2718 "ins v0.d[1], x18\n" // Finish loading v0
2719 "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure.
2720 "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
2721 "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure.
2722 "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
2723 "udot v10.4s, v2.16b, v0.b[2]\n"
2724 "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
2725 "udot v11.4s, v2.16b, v0.b[3]\n"
2726 "ins v1.d[1], x18\n" // Finish loading v1
2727 "udot v12.4s, v2.16b, v1.b[0]\n"
2728 "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
2729 "udot v13.4s, v2.16b, v1.b[1]\n"
2730 "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
2731 "udot v14.4s, v2.16b, v1.b[2]\n"
2732
2733 "udot v15.4s, v2.16b, v1.b[3]\n"
2734 "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
2735 "udot v18.4s, v3.16b, v0.b[2]\n"
2736 "ins v4.d[1], x18\n" // Finish loading v4
2737 "udot v19.4s, v3.16b, v0.b[3]\n"
2738 "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
2739 "udot v20.4s, v3.16b, v1.b[0]\n"
2740 "subs %w[depth], %w[depth], #4\n"
2741 "udot v21.4s, v3.16b, v1.b[1]\n"
2742
2743 "udot v22.4s, v3.16b, v1.b[2]\n"
2744
2745 "udot v23.4s, v3.16b, v1.b[3]\n"
2746 "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
2747 "udot v24.4s, v4.16b, v0.b[0]\n"
2748 "ins v2.d[1], x18\n" // Finish loading next v2
2749 "udot v25.4s, v4.16b, v0.b[1]\n"
2750 "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
2751 "udot v26.4s, v4.16b, v0.b[2]\n"
2752
2753 "udot v27.4s, v4.16b, v0.b[3]\n"
2754 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
2755 "udot v28.4s, v4.16b, v1.b[0]\n"
2756 "ins v3.d[1], x18\n" // Finish loading next v3
2757 "udot v29.4s, v4.16b, v1.b[1]\n"
2758 "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
2759 "udot v30.4s, v4.16b, v1.b[2]\n"
2760
2761 "udot v31.4s, v4.16b, v1.b[3]\n"
2762 "bne " GEMMLOWP_LABEL_LOOP "b\n"
2763
2764 // Store accumulators
2765 "mov x0, %[accum_ptr]\n"
2766 "st1 {v8.4s}, [x0], #16\n"
2767 "st1 {v16.4s}, [x0], #16\n"
2768 "st1 {v24.4s}, [x0], #16\n"
2769 "st1 {v9.4s}, [x0], #16\n"
2770 "st1 {v17.4s}, [x0], #16\n"
2771 "st1 {v25.4s}, [x0], #16\n"
2772 "st1 {v10.4s}, [x0], #16\n"
2773 "st1 {v18.4s}, [x0], #16\n"
2774 "st1 {v26.4s}, [x0], #16\n"
2775 "st1 {v11.4s}, [x0], #16\n"
2776 "st1 {v19.4s}, [x0], #16\n"
2777 "st1 {v27.4s}, [x0], #16\n"
2778 "st1 {v12.4s}, [x0], #16\n"
2779 "st1 {v20.4s}, [x0], #16\n"
2780 "st1 {v28.4s}, [x0], #16\n"
2781 "st1 {v13.4s}, [x0], #16\n"
2782 "st1 {v21.4s}, [x0], #16\n"
2783 "st1 {v29.4s}, [x0], #16\n"
2784 "st1 {v14.4s}, [x0], #16\n"
2785 "st1 {v22.4s}, [x0], #16\n"
2786 "st1 {v30.4s}, [x0], #16\n"
2787 "st1 {v15.4s}, [x0], #16\n"
2788 "st1 {v23.4s}, [x0], #16\n"
2789 "st1 {v31.4s}, [x0], #16\n"
2790 : // outputs
2791 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2792 [depth] "+r"(depth)
2793 : // inputs
2794 [accum_ptr] "r"(accum_ptr)
2795 : // clobbers
2796 "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
2797 "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
2798 "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
2799 "v27", "v28", "v29", "v30", "v31");
2800 }
2801 };
2802 #endif // __ARM_FEATURE_DOTPROD
2803
2804 // We don't actually use int32*int32 in production. This is just an
2805 // experiment to help dissociate the effect of integer-vs-float, from the
2806 // effect of operands width.
2807 struct NEON_64bit_GEMM_Int32_WithScalar {
2808 typedef std::int32_t OperandType;
2809 typedef std::int32_t AccumulatorType;
2810 typedef KernelFormat<
2811 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
2812 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
2813 Format;
RunNEON_64bit_GEMM_Int32_WithScalar2814 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2815 AccumulatorType* accum_ptr, int depth) {
2816 asm volatile(
2817 // Load accumulators
2818 "mov x0, %[accum_ptr]\n"
2819 "ld1 {v8.16b}, [x0], #16\n"
2820 "ld1 {v16.16b}, [x0], #16\n"
2821 "ld1 {v24.16b}, [x0], #16\n"
2822 "ld1 {v9.16b}, [x0], #16\n"
2823 "ld1 {v17.16b}, [x0], #16\n"
2824 "ld1 {v25.16b}, [x0], #16\n"
2825 "ld1 {v10.16b}, [x0], #16\n"
2826 "ld1 {v18.16b}, [x0], #16\n"
2827 "ld1 {v26.16b}, [x0], #16\n"
2828 "ld1 {v11.16b}, [x0], #16\n"
2829 "ld1 {v19.16b}, [x0], #16\n"
2830 "ld1 {v27.16b}, [x0], #16\n"
2831 "ld1 {v12.16b}, [x0], #16\n"
2832 "ld1 {v20.16b}, [x0], #16\n"
2833 "ld1 {v28.16b}, [x0], #16\n"
2834 "ld1 {v13.16b}, [x0], #16\n"
2835 "ld1 {v21.16b}, [x0], #16\n"
2836 "ld1 {v29.16b}, [x0], #16\n"
2837 "ld1 {v14.16b}, [x0], #16\n"
2838 "ld1 {v22.16b}, [x0], #16\n"
2839 "ld1 {v30.16b}, [x0], #16\n"
2840 "ld1 {v15.16b}, [x0], #16\n"
2841 "ld1 {v23.16b}, [x0], #16\n"
2842 "ld1 {v31.16b}, [x0], #16\n"
2843
2844 GEMMLOWP_LABEL_LOOP
2845 ":\n"
2846
2847 // Load 2 Rhs cell of size 1x4 each
2848 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n"
2849 "ld1 {v1.4s}, [%[rhs_ptr]], #16\n"
2850
2851 // Load 3 Lhs cells of size 4x1 each
2852 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
2853 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n"
2854 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
2855
2856 // Multiply-accumulate
2857 "mla v8.4s, v2.4s, v0.s[0]\n"
2858 "mla v9.4s, v2.4s, v0.s[1]\n"
2859 "mla v10.4s, v2.4s, v0.s[2]\n"
2860 "mla v11.4s, v2.4s, v0.s[3]\n"
2861 "mla v12.4s, v2.4s, v1.s[0]\n"
2862 "mla v13.4s, v2.4s, v1.s[1]\n"
2863 "mla v14.4s, v2.4s, v1.s[2]\n"
2864 "mla v15.4s, v2.4s, v1.s[3]\n"
2865 "mla v16.4s, v3.4s, v0.s[0]\n"
2866 "mla v17.4s, v3.4s, v0.s[1]\n"
2867 "mla v18.4s, v3.4s, v0.s[2]\n"
2868 "mla v19.4s, v3.4s, v0.s[3]\n"
2869 "mla v20.4s, v3.4s, v1.s[0]\n"
2870 "mla v21.4s, v3.4s, v1.s[1]\n"
2871 "mla v22.4s, v3.4s, v1.s[2]\n"
2872 "mla v23.4s, v3.4s, v1.s[3]\n"
2873 "mla v24.4s, v4.4s, v0.s[0]\n"
2874 "mla v25.4s, v4.4s, v0.s[1]\n"
2875 "mla v26.4s, v4.4s, v0.s[2]\n"
2876 "mla v27.4s, v4.4s, v0.s[3]\n"
2877 "mla v28.4s, v4.4s, v1.s[0]\n"
2878 "mla v29.4s, v4.4s, v1.s[1]\n"
2879 "mla v30.4s, v4.4s, v1.s[2]\n"
2880 "mla v31.4s, v4.4s, v1.s[3]\n"
2881
2882 // Loop. Decrement loop index (depth) by 1, since we just handled 1
2883 // level of depth.
2884 "subs %w[depth], %w[depth], #1\n"
2885 "bne " GEMMLOWP_LABEL_LOOP
2886 "b\n"
2887
2888 // Store accumulators
2889 "mov x0, %[accum_ptr]\n"
2890 "st1 {v8.16b}, [x0], #16\n"
2891 "st1 {v16.16b}, [x0], #16\n"
2892 "st1 {v24.16b}, [x0], #16\n"
2893 "st1 {v9.16b}, [x0], #16\n"
2894 "st1 {v17.16b}, [x0], #16\n"
2895 "st1 {v25.16b}, [x0], #16\n"
2896 "st1 {v10.16b}, [x0], #16\n"
2897 "st1 {v18.16b}, [x0], #16\n"
2898 "st1 {v26.16b}, [x0], #16\n"
2899 "st1 {v11.16b}, [x0], #16\n"
2900 "st1 {v19.16b}, [x0], #16\n"
2901 "st1 {v27.16b}, [x0], #16\n"
2902 "st1 {v12.16b}, [x0], #16\n"
2903 "st1 {v20.16b}, [x0], #16\n"
2904 "st1 {v28.16b}, [x0], #16\n"
2905 "st1 {v13.16b}, [x0], #16\n"
2906 "st1 {v21.16b}, [x0], #16\n"
2907 "st1 {v29.16b}, [x0], #16\n"
2908 "st1 {v14.16b}, [x0], #16\n"
2909 "st1 {v22.16b}, [x0], #16\n"
2910 "st1 {v30.16b}, [x0], #16\n"
2911 "st1 {v15.16b}, [x0], #16\n"
2912 "st1 {v23.16b}, [x0], #16\n"
2913 "st1 {v31.16b}, [x0], #16\n"
2914 : // outputs
2915 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2916 [depth] "+r"(depth)
2917 : // inputs
2918 [accum_ptr] "r"(accum_ptr)
2919 : // clobbers
2920 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
2921 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
2922 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
2923 "v28", "v29", "v30", "v31");
2924 }
2925 };
2926
2927 // Not very efficient kernel, just an experiment to see what we can do
2928 // without using NEON multiply-with-scalar instructions.
2929 struct NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar {
2930 typedef float OperandType;
2931 typedef float AccumulatorType;
2932 typedef KernelFormat<
2933 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
2934 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
2935 Format;
RunNEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar2936 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
2937 AccumulatorType* accum_ptr, int depth) {
2938 asm volatile(
2939 // Load accumulators
2940 "mov x0, %[accum_ptr]\n"
2941 "ld1 {v8.16b}, [x0], #16\n"
2942 "ld1 {v16.16b}, [x0], #16\n"
2943 "ld1 {v24.16b}, [x0], #16\n"
2944 "ld1 {v9.16b}, [x0], #16\n"
2945 "ld1 {v17.16b}, [x0], #16\n"
2946 "ld1 {v25.16b}, [x0], #16\n"
2947 "ld1 {v10.16b}, [x0], #16\n"
2948 "ld1 {v18.16b}, [x0], #16\n"
2949 "ld1 {v26.16b}, [x0], #16\n"
2950 "ld1 {v11.16b}, [x0], #16\n"
2951 "ld1 {v19.16b}, [x0], #16\n"
2952 "ld1 {v27.16b}, [x0], #16\n"
2953 "ld1 {v12.16b}, [x0], #16\n"
2954 "ld1 {v20.16b}, [x0], #16\n"
2955 "ld1 {v28.16b}, [x0], #16\n"
2956 "ld1 {v13.16b}, [x0], #16\n"
2957 "ld1 {v21.16b}, [x0], #16\n"
2958 "ld1 {v29.16b}, [x0], #16\n"
2959 "ld1 {v14.16b}, [x0], #16\n"
2960 "ld1 {v22.16b}, [x0], #16\n"
2961 "ld1 {v30.16b}, [x0], #16\n"
2962 "ld1 {v15.16b}, [x0], #16\n"
2963 "ld1 {v23.16b}, [x0], #16\n"
2964 "ld1 {v31.16b}, [x0], #16\n"
2965
2966 GEMMLOWP_LABEL_LOOP
2967 ":\n"
2968
2969 // Load 2 Rhs cell of size 1x4 each
2970 "ld1 {v5.4s}, [%[rhs_ptr]], #16\n"
2971 "ld1 {v6.4s}, [%[rhs_ptr]], #16\n"
2972
2973 // Load 3 Lhs cells of size 4x1 each
2974 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
2975 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n"
2976 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
2977
2978 // Multiply-accumulate
2979 "dup v0.4s, v5.s[0]\n"
2980 "dup v1.4s, v5.s[1]\n"
2981 "fmla v8.4s, v2.4s, v0.4s\n"
2982 "fmla v16.4s, v3.4s, v0.4s\n"
2983 "fmla v24.4s, v4.4s, v0.4s\n"
2984 "fmla v9.4s, v2.4s, v1.4s\n"
2985 "fmla v17.4s, v3.4s, v1.4s\n"
2986 "fmla v25.4s, v4.4s, v1.4s\n"
2987 "dup v0.4s, v5.s[2]\n"
2988 "dup v1.4s, v5.s[3]\n"
2989 "fmla v10.4s, v2.4s, v0.4s\n"
2990 "fmla v18.4s, v3.4s, v0.4s\n"
2991 "fmla v26.4s, v4.4s, v0.4s\n"
2992 "fmla v11.4s, v2.4s, v1.4s\n"
2993 "fmla v19.4s, v3.4s, v1.4s\n"
2994 "fmla v27.4s, v4.4s, v1.4s\n"
2995 "dup v0.4s, v6.s[0]\n"
2996 "dup v1.4s, v6.s[1]\n"
2997 "fmla v12.4s, v2.4s, v0.4s\n"
2998 "fmla v20.4s, v3.4s, v0.4s\n"
2999 "fmla v28.4s, v4.4s, v0.4s\n"
3000 "fmla v13.4s, v2.4s, v1.4s\n"
3001 "fmla v21.4s, v3.4s, v1.4s\n"
3002 "fmla v29.4s, v4.4s, v1.4s\n"
3003 "dup v0.4s, v6.s[2]\n"
3004 "dup v1.4s, v6.s[3]\n"
3005 "fmla v14.4s, v2.4s, v0.4s\n"
3006 "fmla v22.4s, v3.4s, v0.4s\n"
3007 "fmla v30.4s, v4.4s, v0.4s\n"
3008 "fmla v15.4s, v2.4s, v1.4s\n"
3009 "fmla v23.4s, v3.4s, v1.4s\n"
3010 "fmla v31.4s, v4.4s, v1.4s\n"
3011
3012 // Loop. Decrement loop index (depth) by 1, since we just handled 1
3013 // level of depth.
3014 "subs %w[depth], %w[depth], #1\n"
3015 "bne " GEMMLOWP_LABEL_LOOP
3016 "b\n"
3017
3018 // Store accumulators
3019 "mov x0, %[accum_ptr]\n"
3020 "st1 {v8.16b}, [x0], #16\n"
3021 "st1 {v16.16b}, [x0], #16\n"
3022 "st1 {v24.16b}, [x0], #16\n"
3023 "st1 {v9.16b}, [x0], #16\n"
3024 "st1 {v17.16b}, [x0], #16\n"
3025 "st1 {v25.16b}, [x0], #16\n"
3026 "st1 {v10.16b}, [x0], #16\n"
3027 "st1 {v18.16b}, [x0], #16\n"
3028 "st1 {v26.16b}, [x0], #16\n"
3029 "st1 {v11.16b}, [x0], #16\n"
3030 "st1 {v19.16b}, [x0], #16\n"
3031 "st1 {v27.16b}, [x0], #16\n"
3032 "st1 {v12.16b}, [x0], #16\n"
3033 "st1 {v20.16b}, [x0], #16\n"
3034 "st1 {v28.16b}, [x0], #16\n"
3035 "st1 {v13.16b}, [x0], #16\n"
3036 "st1 {v21.16b}, [x0], #16\n"
3037 "st1 {v29.16b}, [x0], #16\n"
3038 "st1 {v14.16b}, [x0], #16\n"
3039 "st1 {v22.16b}, [x0], #16\n"
3040 "st1 {v30.16b}, [x0], #16\n"
3041 "st1 {v15.16b}, [x0], #16\n"
3042 "st1 {v23.16b}, [x0], #16\n"
3043 "st1 {v31.16b}, [x0], #16\n"
3044 : // outputs
3045 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
3046 [depth] "+r"(depth)
3047 : // inputs
3048 [accum_ptr] "r"(accum_ptr)
3049 : // clobbers
3050 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
3051 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
3052 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
3053 "v28", "v29", "v30", "v31");
3054 }
3055 };
3056
3057 // This is the "most natural" kernel, using NEON multiply-with-scalar
3058 // instructions.
3059 struct NEON_64bit_GEMM_Float32_WithScalar {
3060 typedef float OperandType;
3061 typedef float AccumulatorType;
3062 typedef KernelFormat<
3063 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
3064 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
3065 Format;
RunNEON_64bit_GEMM_Float32_WithScalar3066 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3067 AccumulatorType* accum_ptr, int depth) {
3068 asm volatile(
3069 // Load accumulators
3070 "mov x0, %[accum_ptr]\n"
3071 "ld1 {v8.16b}, [x0], #16\n"
3072 "ld1 {v16.16b}, [x0], #16\n"
3073 "ld1 {v24.16b}, [x0], #16\n"
3074 "ld1 {v9.16b}, [x0], #16\n"
3075 "ld1 {v17.16b}, [x0], #16\n"
3076 "ld1 {v25.16b}, [x0], #16\n"
3077 "ld1 {v10.16b}, [x0], #16\n"
3078 "ld1 {v18.16b}, [x0], #16\n"
3079 "ld1 {v26.16b}, [x0], #16\n"
3080 "ld1 {v11.16b}, [x0], #16\n"
3081 "ld1 {v19.16b}, [x0], #16\n"
3082 "ld1 {v27.16b}, [x0], #16\n"
3083 "ld1 {v12.16b}, [x0], #16\n"
3084 "ld1 {v20.16b}, [x0], #16\n"
3085 "ld1 {v28.16b}, [x0], #16\n"
3086 "ld1 {v13.16b}, [x0], #16\n"
3087 "ld1 {v21.16b}, [x0], #16\n"
3088 "ld1 {v29.16b}, [x0], #16\n"
3089 "ld1 {v14.16b}, [x0], #16\n"
3090 "ld1 {v22.16b}, [x0], #16\n"
3091 "ld1 {v30.16b}, [x0], #16\n"
3092 "ld1 {v15.16b}, [x0], #16\n"
3093 "ld1 {v23.16b}, [x0], #16\n"
3094 "ld1 {v31.16b}, [x0], #16\n"
3095
3096 GEMMLOWP_LABEL_LOOP
3097 ":\n"
3098
3099 // Load 2 Rhs cell of size 1x4 each
3100 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n"
3101 "ld1 {v1.4s}, [%[rhs_ptr]], #16\n"
3102
3103 // Load 3 Lhs cells of size 4x1 each
3104 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
3105 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n"
3106 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
3107
3108 // Multiply-accumulate
3109 "fmla v8.4s, v2.4s, v0.s[0]\n"
3110 "fmla v9.4s, v2.4s, v0.s[1]\n"
3111 "fmla v10.4s, v2.4s, v0.s[2]\n"
3112 "fmla v11.4s, v2.4s, v0.s[3]\n"
3113 "fmla v12.4s, v2.4s, v1.s[0]\n"
3114 "fmla v13.4s, v2.4s, v1.s[1]\n"
3115 "fmla v14.4s, v2.4s, v1.s[2]\n"
3116 "fmla v15.4s, v2.4s, v1.s[3]\n"
3117 "fmla v16.4s, v3.4s, v0.s[0]\n"
3118 "fmla v17.4s, v3.4s, v0.s[1]\n"
3119 "fmla v18.4s, v3.4s, v0.s[2]\n"
3120 "fmla v19.4s, v3.4s, v0.s[3]\n"
3121 "fmla v20.4s, v3.4s, v1.s[0]\n"
3122 "fmla v21.4s, v3.4s, v1.s[1]\n"
3123 "fmla v22.4s, v3.4s, v1.s[2]\n"
3124 "fmla v23.4s, v3.4s, v1.s[3]\n"
3125 "fmla v24.4s, v4.4s, v0.s[0]\n"
3126 "fmla v25.4s, v4.4s, v0.s[1]\n"
3127 "fmla v26.4s, v4.4s, v0.s[2]\n"
3128 "fmla v27.4s, v4.4s, v0.s[3]\n"
3129 "fmla v28.4s, v4.4s, v1.s[0]\n"
3130 "fmla v29.4s, v4.4s, v1.s[1]\n"
3131 "fmla v30.4s, v4.4s, v1.s[2]\n"
3132 "fmla v31.4s, v4.4s, v1.s[3]\n"
3133
3134 // Loop. Decrement loop index (depth) by 1, since we just handled 1
3135 // level of depth.
3136 "subs %w[depth], %w[depth], #1\n"
3137 "bne " GEMMLOWP_LABEL_LOOP
3138 "b\n"
3139
3140 // Store accumulators
3141 "mov x0, %[accum_ptr]\n"
3142 "st1 {v8.16b}, [x0], #16\n"
3143 "st1 {v16.16b}, [x0], #16\n"
3144 "st1 {v24.16b}, [x0], #16\n"
3145 "st1 {v9.16b}, [x0], #16\n"
3146 "st1 {v17.16b}, [x0], #16\n"
3147 "st1 {v25.16b}, [x0], #16\n"
3148 "st1 {v10.16b}, [x0], #16\n"
3149 "st1 {v18.16b}, [x0], #16\n"
3150 "st1 {v26.16b}, [x0], #16\n"
3151 "st1 {v11.16b}, [x0], #16\n"
3152 "st1 {v19.16b}, [x0], #16\n"
3153 "st1 {v27.16b}, [x0], #16\n"
3154 "st1 {v12.16b}, [x0], #16\n"
3155 "st1 {v20.16b}, [x0], #16\n"
3156 "st1 {v28.16b}, [x0], #16\n"
3157 "st1 {v13.16b}, [x0], #16\n"
3158 "st1 {v21.16b}, [x0], #16\n"
3159 "st1 {v29.16b}, [x0], #16\n"
3160 "st1 {v14.16b}, [x0], #16\n"
3161 "st1 {v22.16b}, [x0], #16\n"
3162 "st1 {v30.16b}, [x0], #16\n"
3163 "st1 {v15.16b}, [x0], #16\n"
3164 "st1 {v23.16b}, [x0], #16\n"
3165 "st1 {v31.16b}, [x0], #16\n"
3166 : // outputs
3167 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
3168 [depth] "+r"(depth)
3169 : // inputs
3170 [accum_ptr] "r"(accum_ptr)
3171 : // clobbers
3172 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
3173 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
3174 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
3175 "v28", "v29", "v30", "v31");
3176 }
3177 };
3178
3179 // Faster kernel contributed by ARM. Tuned for A57.
3180 struct NEON_64bit_GEMM_Float32_WithScalar_A57 {
3181 typedef float OperandType;
3182 typedef float AccumulatorType;
3183 typedef KernelFormat<
3184 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
3185 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
3186 Format;
RunNEON_64bit_GEMM_Float32_WithScalar_A573187 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3188 AccumulatorType* accum_ptr, int depth) {
3189 asm volatile(
3190 // Load accumulators
3191 "mov x0, %[accum_ptr]\n"
3192 "ld1 {v8.16b}, [x0], #16\n"
3193 "ld1 {v16.16b}, [x0], #16\n"
3194 "ld1 {v24.16b}, [x0], #16\n"
3195 "ld1 {v9.16b}, [x0], #16\n"
3196 "ld1 {v17.16b}, [x0], #16\n"
3197 "ld1 {v25.16b}, [x0], #16\n"
3198 "ld1 {v10.16b}, [x0], #16\n"
3199 "ld1 {v18.16b}, [x0], #16\n"
3200 "ld1 {v26.16b}, [x0], #16\n"
3201 "ld1 {v11.16b}, [x0], #16\n"
3202 "ld1 {v19.16b}, [x0], #16\n"
3203 "ld1 {v27.16b}, [x0], #16\n"
3204 "ld1 {v12.16b}, [x0], #16\n"
3205 "ld1 {v20.16b}, [x0], #16\n"
3206 "ld1 {v28.16b}, [x0], #16\n"
3207 "ld1 {v13.16b}, [x0], #16\n"
3208 "ld1 {v21.16b}, [x0], #16\n"
3209 "ld1 {v29.16b}, [x0], #16\n"
3210 "ld1 {v14.16b}, [x0], #16\n"
3211 "ld1 {v22.16b}, [x0], #16\n"
3212 "ld1 {v30.16b}, [x0], #16\n"
3213 "ld1 {v15.16b}, [x0], #16\n"
3214 "ld1 {v23.16b}, [x0], #16\n"
3215 "ld1 {v31.16b}, [x0], #16\n"
3216
3217 // The start of the loop assumes first Rhs cell is already loaded, so
3218 // do it here for first iteration.
3219 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n"
3220
3221 // And the same for the first Lhs cell.
3222 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n"
3223
3224 GEMMLOWP_LABEL_LOOP
3225 ":\n"
3226
3227 // Start the MACs at the head of the loop - 1st cell from each side
3228 // already loaded.
3229 "fmla v8.4s, v2.4s, v0.s[0]\n"
3230 "fmla v9.4s, v2.4s, v0.s[1]\n"
3231 "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
3232 "fmla v10.4s, v2.4s, v0.s[2]\n"
3233 "fmla v11.4s, v2.4s, v0.s[3]\n"
3234 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
3235 "fmla v12.4s, v2.4s, v1.s[0]\n"
3236 "fmla v13.4s, v2.4s, v1.s[1]\n"
3237 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
3238 "fmla v14.4s, v2.4s, v1.s[2]\n"
3239 "fmla v15.4s, v2.4s, v1.s[3]\n"
3240 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
3241 // for the next iteration early.
3242 "fmla v16.4s, v3.4s, v0.s[0]\n"
3243 "fmla v17.4s, v3.4s, v0.s[1]\n"
3244 "fmla v18.4s, v3.4s, v0.s[2]\n"
3245 "fmla v19.4s, v3.4s, v0.s[3]\n"
3246 "fmla v20.4s, v3.4s, v1.s[0]\n"
3247 "fmla v21.4s, v3.4s, v1.s[1]\n"
3248 "fmla v22.4s, v3.4s, v1.s[2]\n"
3249 "fmla v23.4s, v3.4s, v1.s[3]\n"
3250 "fmla v24.4s, v4.4s, v0.s[0]\n"
3251 "fmla v25.4s, v4.4s, v0.s[1]\n"
3252 "fmla v26.4s, v4.4s, v0.s[2]\n"
3253 "fmla v27.4s, v4.4s, v0.s[3]\n"
3254 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
3255 // load for the next iteration
3256 // early.
3257 "fmla v28.4s, v4.4s, v1.s[0]\n"
3258 "fmla v29.4s, v4.4s, v1.s[1]\n"
3259 // Loop. Decrement loop index (depth) by 1, since we just handled
3260 // 1 level of depth. Do this a bit before the end of the loop for
3261 // better dispatch on A57.
3262 "subs %w[depth], %w[depth], #1\n"
3263 "fmla v30.4s, v4.4s, v1.s[2]\n"
3264 "fmla v31.4s, v4.4s, v1.s[3]\n"
3265
3266 "bne " GEMMLOWP_LABEL_LOOP
3267 "b\n"
3268
3269 // Store accumulators
3270 "mov x0, %[accum_ptr]\n"
3271 "st1 {v8.16b}, [x0], #16\n"
3272 "st1 {v16.16b}, [x0], #16\n"
3273 "st1 {v24.16b}, [x0], #16\n"
3274 "st1 {v9.16b}, [x0], #16\n"
3275 "st1 {v17.16b}, [x0], #16\n"
3276 "st1 {v25.16b}, [x0], #16\n"
3277 "st1 {v10.16b}, [x0], #16\n"
3278 "st1 {v18.16b}, [x0], #16\n"
3279 "st1 {v26.16b}, [x0], #16\n"
3280 "st1 {v11.16b}, [x0], #16\n"
3281 "st1 {v19.16b}, [x0], #16\n"
3282 "st1 {v27.16b}, [x0], #16\n"
3283 "st1 {v12.16b}, [x0], #16\n"
3284 "st1 {v20.16b}, [x0], #16\n"
3285 "st1 {v28.16b}, [x0], #16\n"
3286 "st1 {v13.16b}, [x0], #16\n"
3287 "st1 {v21.16b}, [x0], #16\n"
3288 "st1 {v29.16b}, [x0], #16\n"
3289 "st1 {v14.16b}, [x0], #16\n"
3290 "st1 {v22.16b}, [x0], #16\n"
3291 "st1 {v30.16b}, [x0], #16\n"
3292 "st1 {v15.16b}, [x0], #16\n"
3293 "st1 {v23.16b}, [x0], #16\n"
3294 "st1 {v31.16b}, [x0], #16\n"
3295 : // outputs
3296 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
3297 [depth] "+r"(depth)
3298 : // inputs
3299 [accum_ptr] "r"(accum_ptr)
3300 : // clobbers
3301 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
3302 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
3303 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
3304 "v28", "v29", "v30", "v31");
3305 }
3306 };
3307
3308 #ifndef __APPLE__
3309 // Faster kernel contributed by ARM. Tuned for A53.
3310 struct NEON_64bit_GEMM_Float32_WithScalar_A53 {
3311 typedef float OperandType;
3312 typedef float AccumulatorType;
3313 typedef KernelFormat<
3314 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
3315 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
3316 Format;
RunNEON_64bit_GEMM_Float32_WithScalar_A533317 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3318 AccumulatorType* accum_ptr, int depth) {
3319 asm volatile(
3320 // Load accumulators
3321 "mov x0, %[accum_ptr]\n"
3322 "ld1 {v8.16b}, [x0], #16\n"
3323 "ld1 {v16.16b}, [x0], #16\n"
3324 "ld1 {v24.16b}, [x0], #16\n"
3325 "ld1 {v9.16b}, [x0], #16\n"
3326 "ld1 {v17.16b}, [x0], #16\n"
3327 "ld1 {v25.16b}, [x0], #16\n"
3328 "ld1 {v10.16b}, [x0], #16\n"
3329 "ld1 {v18.16b}, [x0], #16\n"
3330 "ld1 {v26.16b}, [x0], #16\n"
3331 "ld1 {v11.16b}, [x0], #16\n"
3332 "ld1 {v19.16b}, [x0], #16\n"
3333 "ld1 {v27.16b}, [x0], #16\n"
3334 "ld1 {v12.16b}, [x0], #16\n"
3335 "ld1 {v20.16b}, [x0], #16\n"
3336 "ld1 {v28.16b}, [x0], #16\n"
3337 "ld1 {v13.16b}, [x0], #16\n"
3338 "ld1 {v21.16b}, [x0], #16\n"
3339 "ld1 {v29.16b}, [x0], #16\n"
3340 "ld1 {v14.16b}, [x0], #16\n"
3341 "ld1 {v22.16b}, [x0], #16\n"
3342 "ld1 {v30.16b}, [x0], #16\n"
3343 "ld1 {v15.16b}, [x0], #16\n"
3344 "ld1 {v23.16b}, [x0], #16\n"
3345 "ld1 {v31.16b}, [x0], #16\n"
3346
3347 // For A53, a very different-looking loop is needed.
3348 //
3349 // The main reason for this is that on A53 128-bit loads take two
3350 // cycles during which no dual issue can occur. Doing two separate
3351 // 64-bit loads avoids this issue - they each take one cycle and are
3352 // able to dual issue. Since vector register loads don't dual issue
3353 // with FMLA, we load half the register as normal and the other half
3354 // into an integer register. This second half can then be moved into
3355 // place later with an INS instruction - which will dual issue with a
3356 // later FP load.
3357 //
3358 // For this kernel there are approximately 3 times as many multiplies
3359 // as loads, so it makes sense to structure the loop into blocks of 4
3360 // cycles, with 1 dedicated "load cycle" and 3 "multiply cycles" per
3361 // block. Strictly preserving this structure with NOPs where no load
3362 // is needed seems to result in higher performance.
3363 //
3364 // Choice of x18 to store the upper halves on their way into the
3365 // vector registers is arbitrary. Added to the clobber list so that
3366 // the compiler will make it available.
3367 //
3368 //
3369 // At the start of the loop, it is assumed that v0 is "half loaded" -
3370 // bottom half in place in d0 and the upper half in x18 ready to
3371 // insert. So set that up here for the first iteration:
3372 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell
3373 "ldr x18, [%[rhs_ptr], #8]\n" // Upper half
3374 "add %[rhs_ptr], %[rhs_ptr], #16\n" // Separate increment (needed as
3375 // there is no operation to load at
3376 // reg + 8 but then increment reg
3377 // by 16).
3378
3379 // v2 should be fully loaded - as it's outside the loop proper it's fine
3380 // to use a 128-bit load here.
3381 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" // first Lhs cell
3382
3383 GEMMLOWP_LABEL_LOOP
3384 ":\n"
3385
3386 // First block of four cycles. Multplies all require v2 and v0; v2 is
3387 // loaded earlier and v0 is half loaded and completed in the load
3388 // cycle at the start.
3389 "ldr d1, [%[rhs_ptr]]\n" // "load" cycle - loading bottom half of v1
3390 // (second Rhs cell).
3391 "ins v0.d[1], x18\n" // "load" cycle - moving the upper half of v0 into
3392 // place.
3393 "fmla v8.4s, v2.4s, v0.s[0]\n" // "fmla" cycle 1 - first multiply.
3394 "ldr x18, [%[rhs_ptr], #8]\n" // "fmla" cycle 1 - load upper half of v1
3395 // into x18.
3396 "fmla v9.4s, v2.4s, v0.s[1]\n" // "fmla" cycle 2 - second multiply
3397 "add %[rhs_ptr], %[rhs_ptr], #16\n" // "fmla" cycle 2 - increment Rhs
3398 // pointer (if needed)
3399 "fmla v10.4s, v2.4s, v0.s[2]\n" // "fmla" cycle 3 - third multiply. No
3400 // more work to dual issue.
3401
3402 // Second block. Start loading v3 (second Lhs cell), finish loading v1.
3403 "ldr d3, [%[lhs_ptr]]\n"
3404 "ins v1.d[1], x18\n" // v1 ready here.
3405 "fmla v11.4s, v2.4s, v0.s[3]\n"
3406 "ldr x18, [%[lhs_ptr], #8]\n"
3407 "fmla v12.4s, v2.4s, v1.s[0]\n" // First use of v1.
3408 "add %[lhs_ptr], %[lhs_ptr], #16\n"
3409 "fmla v13.4s, v2.4s, v1.s[1]\n"
3410
3411 // Third block. Start loading v4 (third Lhs cell), finish loading v3.
3412 "ldr d4, [%[lhs_ptr]]\n"
3413 "ins v3.d[1], x18\n" // v3 ready here.
3414 "fmla v14.4s, v2.4s, v1.s[2]\n"
3415 "ldr x18, [%[lhs_ptr], #8]\n"
3416 "fmla v15.4s, v2.4s, v1.s[3]\n"
3417 "add %[lhs_ptr], %[lhs_ptr], #16\n"
3418 "fmla v16.4s, v3.4s, v0.s[0]\n" // First use of v3.
3419
3420 // Fourth block. v2 (first Lhs cell) is now finished with, so start
3421 // loading value for next iteration. Finish loading v4.
3422 "ldr d2, [%[lhs_ptr]]\n"
3423 "ins v4.d[1], x18\n" // v4 ready here.
3424 "fmla v17.4s, v3.4s, v0.s[1]\n"
3425 "ldr x18, [%[lhs_ptr], #8]\n"
3426 "fmla v18.4s, v3.4s, v0.s[2]\n"
3427 "add %[lhs_ptr], %[lhs_ptr], #16\n"
3428 "fmla v19.4s, v3.4s, v0.s[3]\n"
3429
3430 // Fifth block, finish loading v2. No new load to start as the other
3431 // registers are all still live.
3432 "ins v2.d[1], x18\n"
3433 "fmla v20.4s, v3.4s, v1.s[0]\n"
3434 "fmla v21.4s, v3.4s, v1.s[1]\n"
3435 "fmla v22.4s, v3.4s, v1.s[2]\n"
3436
3437 // Sixth block, nothing to load. 2 nops needed as a single nop would
3438 // dual issue with the FMLA and break the timing.
3439 "nop\n"
3440 "nop\n"
3441 "fmla v23.4s, v3.4s, v1.s[3]\n"
3442 "fmla v24.4s, v4.4s, v0.s[0]\n" // First use of v4.
3443 "fmla v25.4s, v4.4s, v0.s[1]\n"
3444
3445 // Seventh block, nothing to load. Decrement the loop counter in this
3446 // block as the last block is very full.
3447 "nop\n"
3448 "nop\n"
3449 "fmla v26.4s, v4.4s, v0.s[2]\n"
3450 "subs %w[depth], %w[depth], #1\n"
3451 "fmla v27.4s, v4.4s, v0.s[3]\n"
3452 "fmla v28.4s, v4.4s, v1.s[0]\n"
3453
3454 // Eighth block - start loading v0 for next iteration.
3455 "ldr d0, [%[rhs_ptr]]\n"
3456 "fmla v29.4s, v4.4s, v1.s[1]\n"
3457 "ldr x18, [%[rhs_ptr], #8]\n"
3458 "fmla v30.4s, v4.4s, v1.s[2]\n"
3459 "add %[rhs_ptr], %[rhs_ptr], #16\n"
3460 "fmla v31.4s, v4.4s, v1.s[3]\n"
3461
3462 // Loop branch. This will dual issue in fmla cycle 3 of the 8th block.
3463 "bne " GEMMLOWP_LABEL_LOOP
3464 "b\n"
3465
3466 // Store accumulators
3467 "mov x0, %[accum_ptr]\n"
3468 "st1 {v8.16b}, [x0], #16\n"
3469 "st1 {v16.16b}, [x0], #16\n"
3470 "st1 {v24.16b}, [x0], #16\n"
3471 "st1 {v9.16b}, [x0], #16\n"
3472 "st1 {v17.16b}, [x0], #16\n"
3473 "st1 {v25.16b}, [x0], #16\n"
3474 "st1 {v10.16b}, [x0], #16\n"
3475 "st1 {v18.16b}, [x0], #16\n"
3476 "st1 {v26.16b}, [x0], #16\n"
3477 "st1 {v11.16b}, [x0], #16\n"
3478 "st1 {v19.16b}, [x0], #16\n"
3479 "st1 {v27.16b}, [x0], #16\n"
3480 "st1 {v12.16b}, [x0], #16\n"
3481 "st1 {v20.16b}, [x0], #16\n"
3482 "st1 {v28.16b}, [x0], #16\n"
3483 "st1 {v13.16b}, [x0], #16\n"
3484 "st1 {v21.16b}, [x0], #16\n"
3485 "st1 {v29.16b}, [x0], #16\n"
3486 "st1 {v14.16b}, [x0], #16\n"
3487 "st1 {v22.16b}, [x0], #16\n"
3488 "st1 {v30.16b}, [x0], #16\n"
3489 "st1 {v15.16b}, [x0], #16\n"
3490 "st1 {v23.16b}, [x0], #16\n"
3491 "st1 {v31.16b}, [x0], #16\n"
3492 : // outputs
3493 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
3494 [depth] "+r"(depth)
3495 : // inputs
3496 [accum_ptr] "r"(accum_ptr)
3497 : // clobbers
3498 "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
3499 "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
3500 "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
3501 "v27", "v28", "v29", "v30", "v31");
3502 }
3503 };
3504 #endif
3505
3506 // Faster kernel contributed by ARM. Tuned for A55r1.
3507 struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 {
3508 typedef float OperandType;
3509 typedef float AccumulatorType;
3510 typedef KernelFormat<
3511 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
3512 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
3513 Format;
RunNEON_64bit_GEMM_Float32_WithScalar_A55r13514 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3515 AccumulatorType* accum_ptr, int depth) {
3516 asm volatile(
3517 // Load accumulators
3518 "mov x0, %[accum_ptr]\n"
3519 "ld1 {v8.4s}, [x0], #16\n"
3520 "ld1 {v16.4s}, [x0], #16\n"
3521 "ld1 {v24.4s}, [x0], #16\n"
3522 "ld1 {v9.4s}, [x0], #16\n"
3523 "ld1 {v17.4s}, [x0], #16\n"
3524 "ld1 {v25.4s}, [x0], #16\n"
3525 "ld1 {v10.4s}, [x0], #16\n"
3526 "ld1 {v18.4s}, [x0], #16\n"
3527 "ld1 {v26.4s}, [x0], #16\n"
3528 "ld1 {v11.4s}, [x0], #16\n"
3529 "ld1 {v19.4s}, [x0], #16\n"
3530 "ld1 {v27.4s}, [x0], #16\n"
3531 "ld1 {v12.4s}, [x0], #16\n"
3532 "ld1 {v20.4s}, [x0], #16\n"
3533 "ld1 {v28.4s}, [x0], #16\n"
3534 "ld1 {v13.4s}, [x0], #16\n"
3535 "ld1 {v21.4s}, [x0], #16\n"
3536 "ld1 {v29.4s}, [x0], #16\n"
3537 "ld1 {v14.4s}, [x0], #16\n"
3538 "ld1 {v22.4s}, [x0], #16\n"
3539 "ld1 {v30.4s}, [x0], #16\n"
3540 "ld1 {v15.4s}, [x0], #16\n"
3541 "ld1 {v23.4s}, [x0], #16\n"
3542 "ld1 {v31.4s}, [x0], #16\n"
3543
3544 // A55r1 requires a hybrid of the A53 and standard approaches.
3545 //
3546 // Like A53, this processor prefers 64-bit loads.
3547 //
3548 // Unlike A53, it is capable of dual-issuing a 64-bit vector load
3549 // (or INS) with a FMLA instruction.
3550 //
3551 // Therefore we aim to issue an FMLA instruction every cycle.
3552 // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a
3553 // scalar 64-bit load and finally an INS to replicate the effect of
3554 // a single 128-bit load.
3555 //
3556 // The loop contains 24 FMLA instructions, and 5 vector registers
3557 // need to be loaded, consuming 15 dual issue slots. This leaves 9
3558 // dual issue slots. Four of these are used for loop housekeeping
3559 // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left
3560 // over (marked by blank lines).
3561 //
3562 // Choice of x18 to store the upper halves on their way into the
3563 // vector registers is arbitrary. Added to the clobber list so that
3564 // the compiler will make it available.
3565
3566
3567 // At the start of the loop, it is assumed that v0 is "half loaded" -
3568 // bottom half in place in d0 and the upper half in x18 ready to
3569 // insert. So set that up here for the first iteration:
3570 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell
3571 "ldr x18, [%[rhs_ptr], #8]\n" // Upper half
3572
3573 // v2-v3 should be fully loaded - as it's outside the loop proper it's fine
3574 // to use a 128-bit load here.
3575 "ldr q2, [%[lhs_ptr]]\n" // first Lhs cell
3576 "ldr q3, [%[lhs_ptr], #16]\n" // second Lhs cell
3577
3578 GEMMLOWP_LABEL_LOOP
3579 ":\n"
3580
3581 "fmla v8.4s, v2.4s, v0.s[0]\n"
3582 "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
3583 "fmla v9.4s, v2.4s, v0.s[1]\n"
3584 "ins v0.d[1], x18\n" // Finish loading v0
3585 "fmla v16.4s, v3.4s, v0.s[0]\n" // out of sequence - used to reduce load/use pressure.
3586 "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
3587 "fmla v17.4s, v3.4s, v0.s[1]\n" // out of sequence - used to reduce load/use pressure.
3588 "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
3589 "fmla v10.4s, v2.4s, v0.s[2]\n"
3590 "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
3591 "fmla v11.4s, v2.4s, v0.s[3]\n"
3592 "ins v1.d[1], x18\n" // Finish loading v1
3593 "fmla v12.4s, v2.4s, v1.s[0]\n"
3594 "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
3595 "fmla v13.4s, v2.4s, v1.s[1]\n"
3596 "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
3597 "fmla v14.4s, v2.4s, v1.s[2]\n"
3598
3599 "fmla v15.4s, v2.4s, v1.s[3]\n"
3600 "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
3601 "fmla v18.4s, v3.4s, v0.s[2]\n"
3602 "ins v4.d[1], x18\n" // Finish loading v4
3603 "fmla v19.4s, v3.4s, v0.s[3]\n"
3604 "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
3605 "fmla v20.4s, v3.4s, v1.s[0]\n"
3606 "subs %w[depth], %w[depth], #1\n"
3607 "fmla v21.4s, v3.4s, v1.s[1]\n"
3608
3609 "fmla v22.4s, v3.4s, v1.s[2]\n"
3610
3611 "fmla v23.4s, v3.4s, v1.s[3]\n"
3612 "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
3613 "fmla v24.4s, v4.4s, v0.s[0]\n"
3614 "ins v2.d[1], x18\n" // Finish loading next v2
3615 "fmla v25.4s, v4.4s, v0.s[1]\n"
3616 "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
3617 "fmla v26.4s, v4.4s, v0.s[2]\n"
3618
3619 "fmla v27.4s, v4.4s, v0.s[3]\n"
3620 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
3621 "fmla v28.4s, v4.4s, v1.s[0]\n"
3622 "ins v3.d[1], x18\n" // Finish loading next v3
3623 "fmla v29.4s, v4.4s, v1.s[1]\n"
3624 "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
3625 "fmla v30.4s, v4.4s, v1.s[2]\n"
3626
3627 "fmla v31.4s, v4.4s, v1.s[3]\n"
3628 "bne " GEMMLOWP_LABEL_LOOP "b\n"
3629
3630 // Store accumulators
3631 "mov x0, %[accum_ptr]\n"
3632 "st1 {v8.4s}, [x0], #16\n"
3633 "st1 {v16.4s}, [x0], #16\n"
3634 "st1 {v24.4s}, [x0], #16\n"
3635 "st1 {v9.4s}, [x0], #16\n"
3636 "st1 {v17.4s}, [x0], #16\n"
3637 "st1 {v25.4s}, [x0], #16\n"
3638 "st1 {v10.4s}, [x0], #16\n"
3639 "st1 {v18.4s}, [x0], #16\n"
3640 "st1 {v26.4s}, [x0], #16\n"
3641 "st1 {v11.4s}, [x0], #16\n"
3642 "st1 {v19.4s}, [x0], #16\n"
3643 "st1 {v27.4s}, [x0], #16\n"
3644 "st1 {v12.4s}, [x0], #16\n"
3645 "st1 {v20.4s}, [x0], #16\n"
3646 "st1 {v28.4s}, [x0], #16\n"
3647 "st1 {v13.4s}, [x0], #16\n"
3648 "st1 {v21.4s}, [x0], #16\n"
3649 "st1 {v29.4s}, [x0], #16\n"
3650 "st1 {v14.4s}, [x0], #16\n"
3651 "st1 {v22.4s}, [x0], #16\n"
3652 "st1 {v30.4s}, [x0], #16\n"
3653 "st1 {v15.4s}, [x0], #16\n"
3654 "st1 {v23.4s}, [x0], #16\n"
3655 "st1 {v31.4s}, [x0], #16\n"
3656 : // outputs
3657 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
3658 [depth] "+r"(depth)
3659 : // inputs
3660 [accum_ptr] "r"(accum_ptr)
3661 : // clobbers
3662 "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
3663 "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
3664 "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
3665 "v27", "v28", "v29", "v30", "v31");
3666 }
3667 };
3668
3669 #endif // __aarch64__
3670
3671 #if defined(__arm__) || defined(__aarch64__)
3672 #ifndef __aarch64__
vpaddq_s32(int32x4_t a,int32x4_t b)3673 inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
3674 const int32x2_t c = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
3675 const int32x2_t d = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
3676 return vcombine_s32(c, d);
3677 }
3678 #endif
3679
3680 // C++ intrinsics-based variant of the deep, int8, fast kernel
3681 template <int Cols>
3682 struct NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics {
3683 typedef std::int8_t OperandType;
3684 typedef std::int32_t AccumulatorType;
3685 typedef KernelFormat<
3686 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
3687 KernelSideFormat<CellFormat<Cols, 16, CellOrder::WidthMajor>, 1> >
3688 Format;
RunNEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics3689 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3690 AccumulatorType* accum_ptr, int depth) {
3691 int32x4_t acc[4][Cols];
3692 for (int i = 0; i < 4; i++) {
3693 for (int j = 0; j < Cols; j++) {
3694 acc[i][j] = vdupq_n_s32(0);
3695 }
3696 }
3697 for (int d = 0; d < depth; d += 16) {
3698 int8x16_t lhs[4];
3699 for (int i = 0; i < 4; i++) {
3700 lhs[i] = vld1q_s8(lhs_ptr + 16 * i);
3701 }
3702 int8x16_t rhs[Cols];
3703 for (int i = 0; i < Cols; i++) {
3704 rhs[i] = vld1q_s8(rhs_ptr + 16 * i);
3705 }
3706 for (int i = 0; i < 4; i++) {
3707 for (int j = 0; j < Cols; j++) {
3708 int16x8_t local_acc =
3709 vmull_s8(vget_low_s8(lhs[i]), vget_low_s8(rhs[j]));
3710 local_acc =
3711 vmlal_s8(local_acc, vget_high_s8(lhs[i]), vget_high_s8(rhs[j]));
3712 acc[i][j] = vpadalq_s16(acc[i][j], local_acc);
3713 }
3714 }
3715 lhs_ptr += 64;
3716 rhs_ptr += 16 * Cols;
3717 }
3718 for (int i = 0; i < Cols; i++) {
3719 int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]);
3720 int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]);
3721 int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1);
3722 int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i);
3723 dst_val = vaddq_s32(dst_val, acc_4x);
3724 vst1q_s32(accum_ptr + 4 * i, dst_val);
3725 }
3726 }
3727 };
3728
3729 using NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics =
3730 NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics<4>;
3731
3732 using NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics =
3733 NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics<2>;
3734
3735 // C++ intrinsics-based variant of the wide, uint8, general kernel
3736 template <int RhsCells>
3737 struct NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics {
3738 typedef std::uint8_t OperandType;
3739 typedef std::int32_t AccumulatorType;
3740 typedef KernelFormat<
3741 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
3742 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, RhsCells> >
3743 Format;
RunNEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics3744 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3745 AccumulatorType* accum_ptr, int depth) {
3746 int32x4_t acc[3][4 * RhsCells];
3747 for (int i = 0; i < 3; i++) {
3748 for (int j = 0; j < 4 * RhsCells; j++) {
3749 acc[i][j] = vld1q_s32(accum_ptr + 4 * (i + 3 * j));
3750 }
3751 }
3752 for (int d = 0; d < depth; d += 2) {
3753 int16x8_t lhs[3];
3754 for (int i = 0; i < 3; i++) {
3755 lhs[i] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(lhs_ptr + 8 * i)));
3756 }
3757 int16x8_t rhs[RhsCells];
3758 for (int i = 0; i < RhsCells; i++) {
3759 rhs[i] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(rhs_ptr + 8 * i)));
3760 }
3761 for (int i = 0; i < 3; i++) {
3762 for (int j = 0; j < RhsCells; j++) {
3763 acc[i][4 * j + 0] = vmlal_lane_s16(
3764 acc[i][4 * j + 0], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 0);
3765 acc[i][4 * j + 1] = vmlal_lane_s16(
3766 acc[i][4 * j + 1], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 1);
3767 acc[i][4 * j + 2] = vmlal_lane_s16(
3768 acc[i][4 * j + 2], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 2);
3769 acc[i][4 * j + 3] = vmlal_lane_s16(
3770 acc[i][4 * j + 3], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 3);
3771 acc[i][4 * j + 0] =
3772 vmlal_lane_s16(acc[i][4 * j + 0], vget_high_s16(lhs[i]),
3773 vget_high_s16(rhs[j]), 0);
3774 acc[i][4 * j + 1] =
3775 vmlal_lane_s16(acc[i][4 * j + 1], vget_high_s16(lhs[i]),
3776 vget_high_s16(rhs[j]), 1);
3777 acc[i][4 * j + 2] =
3778 vmlal_lane_s16(acc[i][4 * j + 2], vget_high_s16(lhs[i]),
3779 vget_high_s16(rhs[j]), 2);
3780 acc[i][4 * j + 3] =
3781 vmlal_lane_s16(acc[i][4 * j + 3], vget_high_s16(lhs[i]),
3782 vget_high_s16(rhs[j]), 3);
3783 }
3784 }
3785 lhs_ptr += 24;
3786 rhs_ptr += 8 * RhsCells;
3787 }
3788 for (int i = 0; i < 3; i++) {
3789 for (int j = 0; j < 4 * RhsCells; j++) {
3790 vst1q_s32(accum_ptr + 4 * (i + 3 * j), acc[i][j]);
3791 }
3792 }
3793 }
3794 };
3795
3796 using NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics =
3797 NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<1>;
3798
3799 using NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics =
3800 NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<2>;
3801
3802 template <int RhsCells>
3803 struct NEON_GEMM_Float32_WithScalar_intrinsics {
3804 typedef float OperandType;
3805 typedef float AccumulatorType;
3806 typedef KernelFormat<
3807 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
3808 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, RhsCells> >
3809 Format;
RunNEON_GEMM_Float32_WithScalar_intrinsics3810 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3811 AccumulatorType* accum_ptr, int depth) {
3812 float32x4_t acc[3][4 * RhsCells];
3813 for (int i = 0; i < 3; i++) {
3814 for (int j = 0; j < 4 * RhsCells; j++) {
3815 acc[i][j] = vld1q_f32(accum_ptr + 4 * (i + 3 * j));
3816 }
3817 }
3818 for (int d = 0; d < depth; d++) {
3819 float32x4_t lhs[3];
3820 for (int i = 0; i < 3; i++) {
3821 lhs[i] = vld1q_f32(lhs_ptr + 4 * i);
3822 }
3823 float32x4_t rhs[RhsCells];
3824 for (int i = 0; i < RhsCells; i++) {
3825 rhs[i] = vld1q_f32(rhs_ptr + 4 * i);
3826 }
3827 for (int i = 0; i < 3; i++) {
3828 for (int j = 0; j < RhsCells; j++) {
3829 acc[i][4 * j + 0] = vmlaq_lane_f32(acc[i][4 * j + 0], lhs[i],
3830 vget_low_f32(rhs[j]), 0);
3831 acc[i][4 * j + 1] = vmlaq_lane_f32(acc[i][4 * j + 1], lhs[i],
3832 vget_low_f32(rhs[j]), 1);
3833 acc[i][4 * j + 2] = vmlaq_lane_f32(acc[i][4 * j + 2], lhs[i],
3834 vget_high_f32(rhs[j]), 0);
3835 acc[i][4 * j + 3] = vmlaq_lane_f32(acc[i][4 * j + 3], lhs[i],
3836 vget_high_f32(rhs[j]), 1);
3837 }
3838 }
3839 lhs_ptr += 12;
3840 rhs_ptr += 4 * RhsCells;
3841 }
3842 for (int i = 0; i < 3; i++) {
3843 for (int j = 0; j < 4 * RhsCells; j++) {
3844 vst1q_f32(accum_ptr + 4 * (i + 3 * j), acc[i][j]);
3845 }
3846 }
3847 }
3848 };
3849
3850 using NEON_32bit_GEMM_Float32_WithScalar_intrinsics =
3851 NEON_GEMM_Float32_WithScalar_intrinsics<1>;
3852
3853 using NEON_64bit_GEMM_Float32_WithScalar_intrinsics =
3854 NEON_GEMM_Float32_WithScalar_intrinsics<2>;
3855 #endif // __arm__ || __aarch64__
3856
3857 #ifdef __mips
workaround_msa_maddv_w(v4i32 a,v4i32 b,v4i32 c)3858 static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
3859 // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
3860 #if 0
3861 return __builtin_msa_maddv_w(a, b, c);
3862 #else
3863 asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
3864 // Outputs
3865 : [a] "+f"(a)
3866 // Inputs
3867 : [b] "f"(b), [c] "f"(c));
3868 return a;
3869 #endif
3870 }
3871
3872 // Using 32x32=32 multiplications.
3873 // 20 MSA regs used:
3874 // - 12 accumulators
3875 // - 6 lhs
3876 // - 1 rhs
3877 // - 1 temps/zeroes
3878 // ~55 instructions in the loop.
3879 struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics {
3880 typedef std::uint8_t OperandType;
3881 typedef std::int32_t AccumulatorType;
3882 typedef KernelFormat<
3883 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
3884 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
3885 Format;
RunMSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics3886 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
3887 AccumulatorType* accum_ptr, int depth) {
3888 const v16i8 zeroes = __builtin_msa_ldi_b(0);
3889 v4i32 acc[3][4];
3890 // Load accumulators.
3891 for (int i = 0; i < 3; i++) {
3892 for (int j = 0; j < 4; j++) {
3893 acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
3894 }
3895 }
3896
3897 while (depth > 0) {
3898 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
3899 v8i16 lhs[6];
3900 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
3901 lhs[1] =
3902 reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
3903
3904 // Zero-extend 8-bit elements of lhs[] to 16 bits.
3905 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
3906 reinterpret_cast<v16i8>(lhs[0])));
3907 lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
3908 reinterpret_cast<v16i8>(lhs[1])));
3909 lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
3910 reinterpret_cast<v16i8>(lhs[1])));
3911
3912 // Zero-extend 16-bit elements of lhs[] to 32 bits.
3913 lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
3914 lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
3915 lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
3916 lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
3917 lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
3918 lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
3919
3920 // Depth 0.
3921 for (int j = 0; j < 4; j++) {
3922 // Load 1 byte of rhs, making 4 32-bit replicas of it.
3923 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
3924 // Multiply-add into accumulators.
3925 for (int i = 0; i < 3; i++) {
3926 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
3927 }
3928 }
3929
3930 // Depth 1.
3931 for (int j = 0; j < 4; j++) {
3932 // Load 1 byte of rhs, making 4 32-bit replicas of it.
3933 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
3934 // Multiply-add into accumulators.
3935 for (int i = 0; i < 3; i++) {
3936 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
3937 }
3938 }
3939
3940 lhs_ptr += 24;
3941 rhs_ptr += 8;
3942 depth -= 2;
3943 }
3944
3945 // Store accumulators.
3946 for (int i = 0; i < 3; i++) {
3947 for (int j = 0; j < 4; j++) {
3948 __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
3949 }
3950 }
3951 }
3952 };
3953
3954 // Assembly implementation of the above
3955 // MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics.
3956 // Using 32x32=32 multiplications.
3957 // 20 MSA regs used:
3958 // - 12 accumulators
3959 // - 6 lhs
3960 // - 1 rhs
3961 // - 1 temps/zeroes
3962 // ~55 instructions in the loop.
3963 struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly {
3964 typedef std::uint8_t OperandType;
3965 typedef std::int32_t AccumulatorType;
3966 typedef KernelFormat<
3967 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
3968 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
3969 Format;
RunMSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly3970 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
3971 AccumulatorType* accum_ptr, int depth) {
3972 asm volatile(
3973 // Load accumulators
3974 "ld.w $w0, (0*16)(%[accum_ptr])\n"
3975 "ld.w $w4, (1*16)(%[accum_ptr])\n"
3976 "ld.w $w8, (2*16)(%[accum_ptr])\n"
3977 "ld.w $w1, (3*16)(%[accum_ptr])\n"
3978 "ld.w $w5, (4*16)(%[accum_ptr])\n"
3979 "ld.w $w9, (5*16)(%[accum_ptr])\n"
3980 "ld.w $w2, (6*16)(%[accum_ptr])\n"
3981 "ld.w $w6, (7*16)(%[accum_ptr])\n"
3982 "ld.w $w10, (8*16)(%[accum_ptr])\n"
3983 "ld.w $w3, (9*16)(%[accum_ptr])\n"
3984 "ld.w $w7, (10*16)(%[accum_ptr])\n"
3985 "ld.w $w11, (11*16)(%[accum_ptr])\n"
3986 // Set a temp to all zeroes.
3987 "ldi.b $w19, 0\n"
3988
3989 GEMMLOWP_LABEL_LOOP ":\n"
3990 // Overview of register layout:
3991 //
3992 // A half of the 2x4 cell of Rhs is stored in 32bit in w18.
3993 // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17.
3994 // A 12x4 block of accumulators is stored in 32bit in w0-w11.
3995 //
3996 // +------+------+------+------+
3997 // Rhs |w18[0]|w18[1]|w18[2]|w18[3]|
3998 // +------+------+------+------+
3999 //
4000 // | | | | |
4001 //
4002 // Lhs | | | | |
4003 //
4004 // +---+---+ - - - - +------+------+------+------+
4005 // |w12|w15| | w0 | w1 | w2 | w3 |
4006 // |w12|w15| | w0 | w1 | w2 | w3 |
4007 // |w12|w15| | w0 | w1 | w2 | w3 |
4008 // |w12|w15| | w0 | w1 | w2 | w3 |
4009 // +---+---+ - - - - +------+------+------+------+
4010 // |w13|w16| | w4 | w5 | w6 | w7 |
4011 // |w13|w16| | w4 | w5 | w6 | w7 |
4012 // |w13|w16| | w4 | w5 | w6 | w7 |
4013 // |w13|w16| | w4 | w5 | w6 | w7 |
4014 // +---+---+ - - - - +------+------+------+------+
4015 // |w14|w17| | w8 | w9 | w10 | w11 |
4016 // |w14|w17| | w8 | w9 | w10 | w11 |
4017 // |w14|w17| | w8 | w9 | w10 | w11 |
4018 // |w14|w17| | w8 | w9 | w10 | w11 |
4019 // +---+---+ - - - - +------+------+------+------+
4020 //
4021 // Accumulator
4022
4023 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
4024 "ld.b $w12, 0(%[lhs_ptr])\n"
4025 "ld.b $w13, 8(%[lhs_ptr])\n"
4026
4027 // Load 4 bytes of rhs[] for depth 0.
4028 "lbu $a0, 0(%[rhs_ptr])\n"
4029 "lbu $a1, 1(%[rhs_ptr])\n"
4030 "lbu $a2, 2(%[rhs_ptr])\n"
4031 "lbu $a3, 3(%[rhs_ptr])\n"
4032
4033 // Zero-extend 8-bit elements of lhs[] to 16 bits.
4034 "ilvr.b $w12, $w19, $w12\n"
4035 "ilvl.b $w14, $w19, $w13\n"
4036 "ilvr.b $w13, $w19, $w13\n"
4037 // Zero-extend 16-bit elements of lhs[] to 32 bits.
4038 "ilvl.h $w15, $w19, $w12\n"
4039 "ilvl.h $w16, $w19, $w13\n"
4040 "ilvl.h $w17, $w19, $w14\n"
4041 "ilvr.h $w12, $w19, $w12\n"
4042 "ilvr.h $w13, $w19, $w13\n"
4043 "ilvr.h $w14, $w19, $w14\n"
4044
4045 // Depth 0.
4046 "fill.w $w18, $a0\n"
4047 "lbu $a0, 4(%[rhs_ptr])\n"
4048 "maddv.w $w0, $w12, $w18\n"
4049 "maddv.w $w4, $w13, $w18\n"
4050 "maddv.w $w8, $w14, $w18\n"
4051 "fill.w $w18, $a1\n"
4052 "lbu $a1, 5(%[rhs_ptr])\n"
4053 "maddv.w $w1, $w12, $w18\n"
4054 "maddv.w $w5, $w13, $w18\n"
4055 "maddv.w $w9, $w14, $w18\n"
4056 "fill.w $w18, $a2\n"
4057 "lbu $a2, 6(%[rhs_ptr])\n"
4058 "maddv.w $w2, $w12, $w18\n"
4059 "maddv.w $w6, $w13, $w18\n"
4060 "maddv.w $w10, $w14, $w18\n"
4061 "fill.w $w18, $a3\n"
4062 "lbu $a3, 7(%[rhs_ptr])\n"
4063 "maddv.w $w3, $w12, $w18\n"
4064 "maddv.w $w7, $w13, $w18\n"
4065 "maddv.w $w11, $w14, $w18\n"
4066
4067 // Depth 1.
4068 "fill.w $w18, $a0\n"
4069 "maddv.w $w0, $w15, $w18\n"
4070 "maddv.w $w4, $w16, $w18\n"
4071 "maddv.w $w8, $w17, $w18\n"
4072 "fill.w $w18, $a1\n"
4073 "maddv.w $w1, $w15, $w18\n"
4074 "maddv.w $w5, $w16, $w18\n"
4075 "maddv.w $w9, $w17, $w18\n"
4076 "fill.w $w18, $a2\n"
4077 "maddv.w $w2, $w15, $w18\n"
4078 "maddv.w $w6, $w16, $w18\n"
4079 "maddv.w $w10, $w17, $w18\n"
4080 "fill.w $w18, $a3\n"
4081 "maddv.w $w3, $w15, $w18\n"
4082 "maddv.w $w7, $w16, $w18\n"
4083 "maddv.w $w11, $w17, $w18\n"
4084
4085 "addiu %[depth], -2\n"
4086 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
4087 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
4088 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
4089
4090 // Store accumulators.
4091 "st.w $w0, (0*16)(%[accum_ptr])\n"
4092 "st.w $w4, (1*16)(%[accum_ptr])\n"
4093 "st.w $w8, (2*16)(%[accum_ptr])\n"
4094 "st.w $w1, (3*16)(%[accum_ptr])\n"
4095 "st.w $w5, (4*16)(%[accum_ptr])\n"
4096 "st.w $w9, (5*16)(%[accum_ptr])\n"
4097 "st.w $w2, (6*16)(%[accum_ptr])\n"
4098 "st.w $w6, (7*16)(%[accum_ptr])\n"
4099 "st.w $w10, (8*16)(%[accum_ptr])\n"
4100 "st.w $w3, (9*16)(%[accum_ptr])\n"
4101 "st.w $w7, (10*16)(%[accum_ptr])\n"
4102 "st.w $w11, (11*16)(%[accum_ptr])\n"
4103 : // outputs
4104 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
4105 [depth] "+r"(depth)
4106 : // inputs
4107 [accum_ptr] "r"(accum_ptr)
4108 : // clobbers
4109 "memory",
4110 "a0", "a1", "a2", "a3",
4111 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
4112 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
4113 "$f16", "$f17", "$f18", "$f19");
4114 }
4115 };
4116
4117 // Assembly implementation of the above
4118 // MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
4119 // Using 16x16=32 multiplications.
4120 // 20 MSA regs used:
4121 // - 12 accumulators
4122 // - 3 lhs
4123 // - 4 rhs
4124 // - 1 temps/zeroes
4125 // ~45 instructions in the loop.
4126 struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 {
4127 typedef std::uint8_t OperandType;
4128 typedef std::int32_t AccumulatorType;
4129 typedef KernelFormat<
4130 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
4131 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
4132 Format;
RunMSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly24133 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
4134 AccumulatorType* accum_ptr, int depth) {
4135 asm volatile(
4136 // Load accumulators
4137 "ld.w $w0, (0*16)(%[accum_ptr])\n"
4138 "ld.w $w4, (1*16)(%[accum_ptr])\n"
4139 "ld.w $w8, (2*16)(%[accum_ptr])\n"
4140 "ld.w $w1, (3*16)(%[accum_ptr])\n"
4141 "ld.w $w5, (4*16)(%[accum_ptr])\n"
4142 "ld.w $w9, (5*16)(%[accum_ptr])\n"
4143 "ld.w $w2, (6*16)(%[accum_ptr])\n"
4144 "ld.w $w6, (7*16)(%[accum_ptr])\n"
4145 "ld.w $w10, (8*16)(%[accum_ptr])\n"
4146 "ld.w $w3, (9*16)(%[accum_ptr])\n"
4147 "ld.w $w7, (10*16)(%[accum_ptr])\n"
4148 "ld.w $w11, (11*16)(%[accum_ptr])\n"
4149 // Set a temp to all zeroes.
4150 "ldi.b $w19, 0\n"
4151
4152 GEMMLOWP_LABEL_LOOP ":\n"
4153 // Overview of register layout:
4154 //
4155 // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register
4156 // contains 4 replicas of a pair of elements).
4157 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14.
4158 // A 12x4 block of accumulators is stored in 32bit in w0-w11.
4159 //
4160 // +-----+-----+-----+-----+
4161 // Rhs | w15 | w16 | w17 | w18 |
4162 // +-----+-----+-----+-----+
4163 //
4164 // | | | | |
4165 //
4166 // Lhs | | | | |
4167 //
4168 // +---+ - - - - +-----+-----+-----+-----+
4169 // |w12| | w0 | w1 | w2 | w3 |
4170 // |w12| | w0 | w1 | w2 | w3 |
4171 // |w12| | w0 | w1 | w2 | w3 |
4172 // |w12| | w0 | w1 | w2 | w3 |
4173 // +---+ - - - - +-----+-----+-----+-----+
4174 // |w13| | w4 | w5 | w6 | w7 |
4175 // |w13| | w4 | w5 | w6 | w7 |
4176 // |w13| | w4 | w5 | w6 | w7 |
4177 // |w13| | w4 | w5 | w6 | w7 |
4178 // +---+ - - - - +-----+-----+-----+-----+
4179 // |w14| | w8 | w9 | w10 | w11 |
4180 // |w14| | w8 | w9 | w10 | w11 |
4181 // |w14| | w8 | w9 | w10 | w11 |
4182 // |w14| | w8 | w9 | w10 | w11 |
4183 // +---+ - - - - +-----+-----+-----+-----+
4184 //
4185 // Accumulators
4186
4187 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
4188 "ld.b $w12, 0(%[lhs_ptr])\n"
4189 "ld.b $w13, 8(%[lhs_ptr])\n"
4190
4191 // Load 4 bytes of rhs[] for depth 0.
4192 "lbu $a0, 0(%[rhs_ptr])\n"
4193 "lbu $a1, 1(%[rhs_ptr])\n"
4194 "lbu $a2, 2(%[rhs_ptr])\n"
4195 "lbu $a3, 3(%[rhs_ptr])\n"
4196 // Load 4 bytes of rhs[] for depth 1.
4197 "lbu $v0, 4(%[rhs_ptr])\n"
4198 "lbu $v1, 5(%[rhs_ptr])\n"
4199 "lbu $t8, 6(%[rhs_ptr])\n"
4200 "lbu $t9, 7(%[rhs_ptr])\n"
4201
4202 // Zero-extend 8-bit elements of lhs[] to 16 bits.
4203 "ilvr.b $w12, $w19, $w12\n"
4204 "ilvl.b $w14, $w19, $w13\n"
4205 "ilvr.b $w13, $w19, $w13\n"
4206 // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
4207 "ilvl.d $w15, $w19, $w12\n"
4208 "ilvl.d $w16, $w19, $w13\n"
4209 "ilvl.d $w17, $w19, $w14\n"
4210 "ilvr.h $w12, $w15, $w12\n"
4211 "ilvr.h $w13, $w16, $w13\n"
4212 "ilvr.h $w14, $w17, $w14\n"
4213
4214 // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w.
4215 "ins $a0, $v0, 16, 8\n"
4216 "ins $a1, $v1, 16, 8\n"
4217 "ins $a2, $t8, 16, 8\n"
4218 "ins $a3, $t9, 16, 8\n"
4219 // Make 4 replicas of every pair of rhs[] elements.
4220 "fill.w $w15, $a0\n"
4221 "fill.w $w16, $a1\n"
4222 "fill.w $w17, $a2\n"
4223 "fill.w $w18, $a3\n"
4224
4225 // Depths 0 and 1.
4226 // Dot-product-(and)-add doubles multiplicand width.
4227 "dpadd_u.w $w0, $w12, $w15\n"
4228 "dpadd_u.w $w4, $w13, $w15\n"
4229 "dpadd_u.w $w8, $w14, $w15\n"
4230 "dpadd_u.w $w1, $w12, $w16\n"
4231 "dpadd_u.w $w5, $w13, $w16\n"
4232 "dpadd_u.w $w9, $w14, $w16\n"
4233 "dpadd_u.w $w2, $w12, $w17\n"
4234 "dpadd_u.w $w6, $w13, $w17\n"
4235 "dpadd_u.w $w10, $w14, $w17\n"
4236 "dpadd_u.w $w3, $w12, $w18\n"
4237 "dpadd_u.w $w7, $w13, $w18\n"
4238 "dpadd_u.w $w11, $w14, $w18\n"
4239
4240 "addiu %[depth], -2\n"
4241 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
4242 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
4243 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
4244
4245 // Store accumulators.
4246 "st.w $w0, (0*16)(%[accum_ptr])\n"
4247 "st.w $w4, (1*16)(%[accum_ptr])\n"
4248 "st.w $w8, (2*16)(%[accum_ptr])\n"
4249 "st.w $w1, (3*16)(%[accum_ptr])\n"
4250 "st.w $w5, (4*16)(%[accum_ptr])\n"
4251 "st.w $w9, (5*16)(%[accum_ptr])\n"
4252 "st.w $w2, (6*16)(%[accum_ptr])\n"
4253 "st.w $w6, (7*16)(%[accum_ptr])\n"
4254 "st.w $w10, (8*16)(%[accum_ptr])\n"
4255 "st.w $w3, (9*16)(%[accum_ptr])\n"
4256 "st.w $w7, (10*16)(%[accum_ptr])\n"
4257 "st.w $w11, (11*16)(%[accum_ptr])\n"
4258 : // outputs
4259 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
4260 [depth] "+r"(depth)
4261 : // inputs
4262 [accum_ptr] "r"(accum_ptr)
4263 : // clobbers
4264 "memory",
4265 "v0", "v1",
4266 "a0", "a1", "a2", "a3",
4267 "t8", "t9",
4268 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
4269 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
4270 "$f16", "$f17", "$f18", "$f19");
4271 }
4272 };
4273
4274 // Using 32x32=32 multiplications.
4275 // 32 MSA regs used:
4276 // - 24 accumulators
4277 // - 6 lhs
4278 // - 1 rhs
4279 // - 1 temps/zeroes
4280 // ~95 instructions in the loop.
4281 struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics {
4282 typedef std::uint8_t OperandType;
4283 typedef std::uint32_t AccumulatorType;
4284 typedef KernelFormat<
4285 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
4286 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
4287 Format;
RunMSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics4288 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
4289 AccumulatorType* accum_ptr, int depth) {
4290 const v16i8 zeroes = __builtin_msa_ldi_b(0);
4291 v4i32 acc[3][8];
4292 // Load accumulators.
4293 for (int i = 0; i < 3; i++) {
4294 for (int j = 0; j < 8; j++) {
4295 acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
4296 }
4297 }
4298
4299 while (depth > 0) {
4300 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
4301 v8i16 lhs[6];
4302 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
4303 lhs[1] =
4304 reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
4305
4306 // Zero-extend 8-bit elements of lhs[] to 16 bits.
4307 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
4308 reinterpret_cast<v16i8>(lhs[0])));
4309 lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
4310 reinterpret_cast<v16i8>(lhs[1])));
4311 lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
4312 reinterpret_cast<v16i8>(lhs[1])));
4313
4314 // Zero-extend 16-bit elements of lhs[] to 32 bits.
4315 lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
4316 lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
4317 lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
4318 lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
4319 lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
4320 lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
4321
4322 // Depth 0.
4323 for (int j = 0; j < 4; j++) {
4324 // Load 1 byte of rhs, making 4 32-bit replicas of it.
4325 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
4326 // Multiply-add into accumulators.
4327 for (int i = 0; i < 3; i++) {
4328 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
4329 }
4330 }
4331 for (int j = 4; j < 8; j++) {
4332 // Load 1 byte of rhs, making 4 32-bit replicas of it.
4333 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
4334 // Multiply-add into accumulators.
4335 for (int i = 0; i < 3; i++) {
4336 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
4337 }
4338 }
4339
4340 // Depth 1.
4341 for (int j = 0; j < 4; j++) {
4342 // Load 1 byte of rhs, making 4 32-bit replicas of it.
4343 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
4344 // Multiply-add into accumulators.
4345 for (int i = 0; i < 3; i++) {
4346 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
4347 }
4348 }
4349 for (int j = 4; j < 8; j++) {
4350 // Load 1 byte of rhs, making 4 32-bit replicas of it.
4351 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8]));
4352 // Multiply-add into accumulators.
4353 for (int i = 0; i < 3; i++) {
4354 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
4355 }
4356 }
4357
4358 lhs_ptr += 24;
4359 rhs_ptr += 16;
4360 depth -= 2;
4361 }
4362
4363 // Store accumulators.
4364 for (int i = 0; i < 3; i++) {
4365 for (int j = 0; j < 8; j++) {
4366 __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
4367 }
4368 }
4369 }
4370 };
4371
4372 // Assembly implementation of the above
4373 // MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics.
4374 // Using 32x32=32 multiplications.
4375 // 32 MSA regs used:
4376 // - 24 accumulators
4377 // - 6 lhs
4378 // - 1 rhs
4379 // - 1 temps/zeroes
4380 // ~95 instructions in the loop.
4381 struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
4382 typedef std::uint8_t OperandType;
4383 typedef std::uint32_t AccumulatorType;
4384 typedef KernelFormat<
4385 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
4386 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
4387 Format;
RunMSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly4388 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
4389 AccumulatorType* accum_ptr, int depth) {
4390 asm volatile(
4391 // Load accumulators
4392 "ld.w $w0, (0*16)(%[accum_ptr])\n"
4393 "ld.w $w4, (1*16)(%[accum_ptr])\n"
4394 "ld.w $w8, (2*16)(%[accum_ptr])\n"
4395 "ld.w $w1, (3*16)(%[accum_ptr])\n"
4396 "ld.w $w5, (4*16)(%[accum_ptr])\n"
4397 "ld.w $w9, (5*16)(%[accum_ptr])\n"
4398 "ld.w $w2, (6*16)(%[accum_ptr])\n"
4399 "ld.w $w6, (7*16)(%[accum_ptr])\n"
4400 "ld.w $w10, (8*16)(%[accum_ptr])\n"
4401 "ld.w $w3, (9*16)(%[accum_ptr])\n"
4402 "ld.w $w7, (10*16)(%[accum_ptr])\n"
4403 "ld.w $w11, (11*16)(%[accum_ptr])\n"
4404 "ld.w $w12, (12*16)(%[accum_ptr])\n"
4405 "ld.w $w16, (13*16)(%[accum_ptr])\n"
4406 "ld.w $w20, (14*16)(%[accum_ptr])\n"
4407 "ld.w $w13, (15*16)(%[accum_ptr])\n"
4408 "ld.w $w17, (16*16)(%[accum_ptr])\n"
4409 "ld.w $w21, (17*16)(%[accum_ptr])\n"
4410 "ld.w $w14, (18*16)(%[accum_ptr])\n"
4411 "ld.w $w18, (19*16)(%[accum_ptr])\n"
4412 "ld.w $w22, (20*16)(%[accum_ptr])\n"
4413 "ld.w $w15, (21*16)(%[accum_ptr])\n"
4414 "ld.w $w19, (22*16)(%[accum_ptr])\n"
4415 "ld.w $w23, (23*16)(%[accum_ptr])\n"
4416 // Set a temp to all zeroes.
4417 "ldi.b $w31, 0\n"
4418
4419 GEMMLOWP_LABEL_LOOP ":\n"
4420 // Overview of register layout:
4421 //
4422 // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30.
4423 // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29.
4424 // A 12x8 block of accumulators is stored in 32bit in w0-w23.
4425 //
4426 // +------+------+------+------+
4427 // Rhs |w30[0]|w30[1]|w30[2]|w30[3]|
4428 // +------+------+------+------+
4429 //
4430 // | | | | |
4431 //
4432 // Lhs | | | | |
4433 //
4434 // +---+---+ - - - - +------+------+------+------+
4435 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
4436 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
4437 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
4438 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
4439 // +---+---+ - - - - +------+------+------+------+
4440 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
4441 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
4442 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
4443 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
4444 // +---+---+ - - - - +------+------+------+------+
4445 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
4446 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
4447 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
4448 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
4449 // +---+---+ - - - - +------+------+------+------+
4450 //
4451 // Accumulator
4452
4453 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
4454 "ld.b $w24, 0(%[lhs_ptr])\n"
4455 "ld.b $w25, 8(%[lhs_ptr])\n"
4456
4457 // Load 4 bytes of rhs[] for the first half of depth 0.
4458 "lbu $a0, 0(%[rhs_ptr])\n"
4459 "lbu $a1, 1(%[rhs_ptr])\n"
4460 "lbu $a2, 2(%[rhs_ptr])\n"
4461 "lbu $a3, 3(%[rhs_ptr])\n"
4462
4463 // Zero-extend 8-bit elements of lhs[] to 16 bits.
4464 "ilvr.b $w24, $w31, $w24\n"
4465 "ilvl.b $w26, $w31, $w25\n"
4466 "ilvr.b $w25, $w31, $w25\n"
4467 // Zero-extend 16-bit elements of lhs[] to 32 bits.
4468 "ilvl.h $w27, $w31, $w24\n"
4469 "ilvl.h $w28, $w31, $w25\n"
4470 "ilvl.h $w29, $w31, $w26\n"
4471 "ilvr.h $w24, $w31, $w24\n"
4472 "ilvr.h $w25, $w31, $w25\n"
4473 "ilvr.h $w26, $w31, $w26\n"
4474
4475 // Depth 0.
4476 "fill.w $w30, $a0\n"
4477 "lbu $a0, 8(%[rhs_ptr])\n"
4478 "maddv.w $w0, $w24, $w30\n"
4479 "maddv.w $w4, $w25, $w30\n"
4480 "maddv.w $w8, $w26, $w30\n"
4481 "fill.w $w30, $a1\n"
4482 "lbu $a1, 9(%[rhs_ptr])\n"
4483 "maddv.w $w1, $w24, $w30\n"
4484 "maddv.w $w5, $w25, $w30\n"
4485 "maddv.w $w9, $w26, $w30\n"
4486 "fill.w $w30, $a2\n"
4487 "lbu $a2, 10(%[rhs_ptr])\n"
4488 "maddv.w $w2, $w24, $w30\n"
4489 "maddv.w $w6, $w25, $w30\n"
4490 "maddv.w $w10, $w26, $w30\n"
4491 "fill.w $w30, $a3\n"
4492 "lbu $a3, 11(%[rhs_ptr])\n"
4493 "maddv.w $w3, $w24, $w30\n"
4494 "maddv.w $w7, $w25, $w30\n"
4495 "maddv.w $w11, $w26, $w30\n"
4496
4497 "fill.w $w30, $a0\n"
4498 "lbu $a0, 4(%[rhs_ptr])\n"
4499 "maddv.w $w12, $w24, $w30\n"
4500 "maddv.w $w16, $w25, $w30\n"
4501 "maddv.w $w20, $w26, $w30\n"
4502 "fill.w $w30, $a1\n"
4503 "lbu $a1, 5(%[rhs_ptr])\n"
4504 "maddv.w $w13, $w24, $w30\n"
4505 "maddv.w $w17, $w25, $w30\n"
4506 "maddv.w $w21, $w26, $w30\n"
4507 "fill.w $w30, $a2\n"
4508 "lbu $a2, 6(%[rhs_ptr])\n"
4509 "maddv.w $w14, $w24, $w30\n"
4510 "maddv.w $w18, $w25, $w30\n"
4511 "maddv.w $w22, $w26, $w30\n"
4512 "fill.w $w30, $a3\n"
4513 "lbu $a3, 7(%[rhs_ptr])\n"
4514 "maddv.w $w15, $w24, $w30\n"
4515 "maddv.w $w19, $w25, $w30\n"
4516 "maddv.w $w23, $w26, $w30\n"
4517
4518 // Depth 1.
4519 "fill.w $w30, $a0\n"
4520 "lbu $a0, 12(%[rhs_ptr])\n"
4521 "maddv.w $w0, $w27, $w30\n"
4522 "maddv.w $w4, $w28, $w30\n"
4523 "maddv.w $w8, $w29, $w30\n"
4524 "fill.w $w30, $a1\n"
4525 "lbu $a1, 13(%[rhs_ptr])\n"
4526 "maddv.w $w1, $w27, $w30\n"
4527 "maddv.w $w5, $w28, $w30\n"
4528 "maddv.w $w9, $w29, $w30\n"
4529 "fill.w $w30, $a2\n"
4530 "lbu $a2, 14(%[rhs_ptr])\n"
4531 "maddv.w $w2, $w27, $w30\n"
4532 "maddv.w $w6, $w28, $w30\n"
4533 "maddv.w $w10, $w29, $w30\n"
4534 "fill.w $w30, $a3\n"
4535 "lbu $a3, 15(%[rhs_ptr])\n"
4536 "maddv.w $w3, $w27, $w30\n"
4537 "maddv.w $w7, $w28, $w30\n"
4538 "maddv.w $w11, $w29, $w30\n"
4539
4540 "fill.w $w30, $a0\n"
4541 "maddv.w $w12, $w27, $w30\n"
4542 "maddv.w $w16, $w28, $w30\n"
4543 "maddv.w $w20, $w29, $w30\n"
4544 "fill.w $w30, $a1\n"
4545 "maddv.w $w13, $w27, $w30\n"
4546 "maddv.w $w17, $w28, $w30\n"
4547 "maddv.w $w21, $w29, $w30\n"
4548 "fill.w $w30, $a2\n"
4549 "maddv.w $w14, $w27, $w30\n"
4550 "maddv.w $w18, $w28, $w30\n"
4551 "maddv.w $w22, $w29, $w30\n"
4552 "fill.w $w30, $a3\n"
4553 "maddv.w $w15, $w27, $w30\n"
4554 "maddv.w $w19, $w28, $w30\n"
4555 "maddv.w $w23, $w29, $w30\n"
4556
4557 "addiu %[depth], -2\n"
4558 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
4559 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
4560 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
4561
4562 // Store accumulators.
4563 "st.w $w0, (0*16)(%[accum_ptr])\n"
4564 "st.w $w4, (1*16)(%[accum_ptr])\n"
4565 "st.w $w8, (2*16)(%[accum_ptr])\n"
4566 "st.w $w1, (3*16)(%[accum_ptr])\n"
4567 "st.w $w5, (4*16)(%[accum_ptr])\n"
4568 "st.w $w9, (5*16)(%[accum_ptr])\n"
4569 "st.w $w2, (6*16)(%[accum_ptr])\n"
4570 "st.w $w6, (7*16)(%[accum_ptr])\n"
4571 "st.w $w10, (8*16)(%[accum_ptr])\n"
4572 "st.w $w3, (9*16)(%[accum_ptr])\n"
4573 "st.w $w7, (10*16)(%[accum_ptr])\n"
4574 "st.w $w11, (11*16)(%[accum_ptr])\n"
4575 "st.w $w12, (12*16)(%[accum_ptr])\n"
4576 "st.w $w16, (13*16)(%[accum_ptr])\n"
4577 "st.w $w20, (14*16)(%[accum_ptr])\n"
4578 "st.w $w13, (15*16)(%[accum_ptr])\n"
4579 "st.w $w17, (16*16)(%[accum_ptr])\n"
4580 "st.w $w21, (17*16)(%[accum_ptr])\n"
4581 "st.w $w14, (18*16)(%[accum_ptr])\n"
4582 "st.w $w18, (19*16)(%[accum_ptr])\n"
4583 "st.w $w22, (20*16)(%[accum_ptr])\n"
4584 "st.w $w15, (21*16)(%[accum_ptr])\n"
4585 "st.w $w19, (22*16)(%[accum_ptr])\n"
4586 "st.w $w23, (23*16)(%[accum_ptr])\n"
4587 : // outputs
4588 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
4589 [depth] "+r"(depth)
4590 : // inputs
4591 [accum_ptr] "r"(accum_ptr)
4592 : // clobbers
4593 "memory",
4594 "a0", "a1", "a2", "a3",
4595 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
4596 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
4597 "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
4598 "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
4599 }
4600 };
4601
4602 // Assembly implementation of the above
4603 // MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
4604 // Using 16x16=32 multiplications.
4605 // 32 MSA regs used:
4606 // - 24 accumulators
4607 // - 3 lhs
4608 // - 4 rhs
4609 // - 1 temps/zeroes
4610 // ~70 instructions in the loop.
4611 struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
4612 typedef std::uint8_t OperandType;
4613 typedef std::uint32_t AccumulatorType;
4614 typedef KernelFormat<
4615 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
4616 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
4617 Format;
RunMSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly24618 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
4619 AccumulatorType* accum_ptr, int depth) {
4620 asm volatile(
4621 // Load accumulators
4622 "ld.w $w0, (0*16)(%[accum_ptr])\n"
4623 "ld.w $w4, (1*16)(%[accum_ptr])\n"
4624 "ld.w $w8, (2*16)(%[accum_ptr])\n"
4625 "ld.w $w1, (3*16)(%[accum_ptr])\n"
4626 "ld.w $w5, (4*16)(%[accum_ptr])\n"
4627 "ld.w $w9, (5*16)(%[accum_ptr])\n"
4628 "ld.w $w2, (6*16)(%[accum_ptr])\n"
4629 "ld.w $w6, (7*16)(%[accum_ptr])\n"
4630 "ld.w $w10, (8*16)(%[accum_ptr])\n"
4631 "ld.w $w3, (9*16)(%[accum_ptr])\n"
4632 "ld.w $w7, (10*16)(%[accum_ptr])\n"
4633 "ld.w $w11, (11*16)(%[accum_ptr])\n"
4634 "ld.w $w12, (12*16)(%[accum_ptr])\n"
4635 "ld.w $w16, (13*16)(%[accum_ptr])\n"
4636 "ld.w $w20, (14*16)(%[accum_ptr])\n"
4637 "ld.w $w13, (15*16)(%[accum_ptr])\n"
4638 "ld.w $w17, (16*16)(%[accum_ptr])\n"
4639 "ld.w $w21, (17*16)(%[accum_ptr])\n"
4640 "ld.w $w14, (18*16)(%[accum_ptr])\n"
4641 "ld.w $w18, (19*16)(%[accum_ptr])\n"
4642 "ld.w $w22, (20*16)(%[accum_ptr])\n"
4643 "ld.w $w15, (21*16)(%[accum_ptr])\n"
4644 "ld.w $w19, (22*16)(%[accum_ptr])\n"
4645 "ld.w $w23, (23*16)(%[accum_ptr])\n"
4646 // Set a temp to all zeroes.
4647 "ldi.b $w31, 0\n"
4648
4649 GEMMLOWP_LABEL_LOOP ":\n"
4650 // Overview of register layout:
4651 //
4652 // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
4653 // (each register contains 4 replicas of a pair of elements).
4654 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
4655 // A 12x8 block of accumulators is stored in 32bit in w0-w23.
4656 //
4657 // +------+------+------+------+
4658 // Rhs |w27 |w28 |w29 |w30 |
4659 // +------+------+------+------+
4660 //
4661 // | | | | |
4662 //
4663 // Lhs | | | | |
4664 //
4665 // +---+ - - - - +------+------+------+------+
4666 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
4667 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
4668 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
4669 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
4670 // +---+ - - - - +------+------+------+------+
4671 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
4672 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
4673 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
4674 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
4675 // +---+ - - - - +------+------+------+------+
4676 // |w26| |w8/20 |w9/21 |w10/22|w11/23|
4677 // |w26| |w8/20 |w9/21 |w10/22|w11/23|
4678 // |w26| |w8/20 |w9/21 |w10/22|w11/23|
4679 // |w26| |w8/20 |w9/21 |w10/22|w11/23|
4680 // +---+ - - - - +------+------+------+------+
4681 //
4682 // Accumulators
4683
4684 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
4685 "ld.b $w24, 0(%[lhs_ptr])\n"
4686 "ld.b $w25, 8(%[lhs_ptr])\n"
4687
4688 // Load 4 bytes of rhs[] for the first half of depth 0.
4689 "lbu $a0, 0(%[rhs_ptr])\n"
4690 "lbu $a1, 1(%[rhs_ptr])\n"
4691 "lbu $a2, 2(%[rhs_ptr])\n"
4692 "lbu $a3, 3(%[rhs_ptr])\n"
4693 // Load 4 bytes of rhs[] for the first half of depth 1.
4694 "lbu $v0, 4(%[rhs_ptr])\n"
4695 "lbu $v1, 5(%[rhs_ptr])\n"
4696 "lbu $t8, 6(%[rhs_ptr])\n"
4697 "lbu $t9, 7(%[rhs_ptr])\n"
4698
4699 // Zero-extend 8-bit elements of lhs[] to 16 bits.
4700 "ilvr.b $w24, $w31, $w24\n"
4701 "ilvl.b $w26, $w31, $w25\n"
4702 "ilvr.b $w25, $w31, $w25\n"
4703 // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
4704 "ilvl.d $w27, $w31, $w24\n"
4705 "ilvl.d $w28, $w31, $w25\n"
4706 "ilvl.d $w29, $w31, $w26\n"
4707 "ilvr.h $w24, $w27, $w24\n"
4708 "ilvr.h $w25, $w28, $w25\n"
4709 "ilvr.h $w26, $w29, $w26\n"
4710
4711 // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
4712 // (for the first half).
4713 "ins $a0, $v0, 16, 8\n"
4714 "ins $a1, $v1, 16, 8\n"
4715 "ins $a2, $t8, 16, 8\n"
4716 "ins $a3, $t9, 16, 8\n"
4717 // Make 4 replicas of every pair of rhs[] elements.
4718 "fill.w $w27, $a0\n"
4719 "fill.w $w28, $a1\n"
4720 "fill.w $w29, $a2\n"
4721 "fill.w $w30, $a3\n"
4722
4723 // Load 4 bytes of rhs[] for the second half of depth 0.
4724 "lbu $a0, 8(%[rhs_ptr])\n"
4725 "lbu $a1, 9(%[rhs_ptr])\n"
4726 "lbu $a2, 10(%[rhs_ptr])\n"
4727 "lbu $a3, 11(%[rhs_ptr])\n"
4728 // Load 4 bytes of rhs[] for the second half of depth 1.
4729 "lbu $v0, 12(%[rhs_ptr])\n"
4730 "lbu $v1, 13(%[rhs_ptr])\n"
4731 "lbu $t8, 14(%[rhs_ptr])\n"
4732 "lbu $t9, 15(%[rhs_ptr])\n"
4733
4734 // First half of depths 0 and 1.
4735 // Dot-product-(and)-add doubles multiplicand width.
4736 "dpadd_u.w $w0, $w24, $w27\n"
4737 "dpadd_u.w $w4, $w25, $w27\n"
4738 "dpadd_u.w $w8, $w26, $w27\n"
4739 "dpadd_u.w $w1, $w24, $w28\n"
4740 "dpadd_u.w $w5, $w25, $w28\n"
4741 "dpadd_u.w $w9, $w26, $w28\n"
4742 "dpadd_u.w $w2, $w24, $w29\n"
4743 "dpadd_u.w $w6, $w25, $w29\n"
4744 "dpadd_u.w $w10, $w26, $w29\n"
4745 "dpadd_u.w $w3, $w24, $w30\n"
4746 "dpadd_u.w $w7, $w25, $w30\n"
4747 "dpadd_u.w $w11, $w26, $w30\n"
4748
4749 // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
4750 // (for the second half).
4751 "ins $a0, $v0, 16, 8\n"
4752 "ins $a1, $v1, 16, 8\n"
4753 "ins $a2, $t8, 16, 8\n"
4754 "ins $a3, $t9, 16, 8\n"
4755 // Make 4 replicas of every pair of rhs[] elements.
4756 "fill.w $w27, $a0\n"
4757 "fill.w $w28, $a1\n"
4758 "fill.w $w29, $a2\n"
4759 "fill.w $w30, $a3\n"
4760
4761 // Second half of depths 0 and 1.
4762 // Dot-product-(and)-add doubles multiplicand width.
4763 "dpadd_u.w $w12, $w24, $w27\n"
4764 "dpadd_u.w $w16, $w25, $w27\n"
4765 "dpadd_u.w $w20, $w26, $w27\n"
4766 "dpadd_u.w $w13, $w24, $w28\n"
4767 "dpadd_u.w $w17, $w25, $w28\n"
4768 "dpadd_u.w $w21, $w26, $w28\n"
4769 "dpadd_u.w $w14, $w24, $w29\n"
4770 "dpadd_u.w $w18, $w25, $w29\n"
4771 "dpadd_u.w $w22, $w26, $w29\n"
4772 "dpadd_u.w $w15, $w24, $w30\n"
4773 "dpadd_u.w $w19, $w25, $w30\n"
4774 "dpadd_u.w $w23, $w26, $w30\n"
4775
4776 "addiu %[depth], -2\n"
4777 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
4778 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
4779 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
4780
4781 // Store accumulators.
4782 "st.w $w0, (0*16)(%[accum_ptr])\n"
4783 "st.w $w4, (1*16)(%[accum_ptr])\n"
4784 "st.w $w8, (2*16)(%[accum_ptr])\n"
4785 "st.w $w1, (3*16)(%[accum_ptr])\n"
4786 "st.w $w5, (4*16)(%[accum_ptr])\n"
4787 "st.w $w9, (5*16)(%[accum_ptr])\n"
4788 "st.w $w2, (6*16)(%[accum_ptr])\n"
4789 "st.w $w6, (7*16)(%[accum_ptr])\n"
4790 "st.w $w10, (8*16)(%[accum_ptr])\n"
4791 "st.w $w3, (9*16)(%[accum_ptr])\n"
4792 "st.w $w7, (10*16)(%[accum_ptr])\n"
4793 "st.w $w11, (11*16)(%[accum_ptr])\n"
4794 "st.w $w12, (12*16)(%[accum_ptr])\n"
4795 "st.w $w16, (13*16)(%[accum_ptr])\n"
4796 "st.w $w20, (14*16)(%[accum_ptr])\n"
4797 "st.w $w13, (15*16)(%[accum_ptr])\n"
4798 "st.w $w17, (16*16)(%[accum_ptr])\n"
4799 "st.w $w21, (17*16)(%[accum_ptr])\n"
4800 "st.w $w14, (18*16)(%[accum_ptr])\n"
4801 "st.w $w18, (19*16)(%[accum_ptr])\n"
4802 "st.w $w22, (20*16)(%[accum_ptr])\n"
4803 "st.w $w15, (21*16)(%[accum_ptr])\n"
4804 "st.w $w19, (22*16)(%[accum_ptr])\n"
4805 "st.w $w23, (23*16)(%[accum_ptr])\n"
4806 : // outputs
4807 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
4808 [depth] "+r"(depth)
4809 : // inputs
4810 [accum_ptr] "r"(accum_ptr)
4811 : // clobbers
4812 "memory",
4813 "v0", "v1",
4814 "a0", "a1", "a2", "a3",
4815 "t8", "t9",
4816 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
4817 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
4818 "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
4819 "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
4820 }
4821 };
4822 #endif // __mips
4823
4824 // BEGIN code copied from gemmlowp/internal/kernel_reference.h
4825
4826 // This kernel is templatized in an arbitrary Format template parameter,
4827 // allowing it to have any arbitrary format.
4828 template <typename tOperandType, typename tAccumulatorType, typename tFormat>
4829 struct ReferenceKernel {
4830 typedef tOperandType OperandType;
4831 typedef tAccumulatorType AccumulatorType;
4832 typedef tFormat Format;
4833
RunReferenceKernel4834 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
4835 AccumulatorType* accum_ptr, int depth) {
4836 const int depth_cells = static_cast<int>(depth / Format::kDepth);
4837
4838 // The outer loop is over the depth dimension.
4839 for (int dc = 0; dc < depth_cells; dc++) {
4840 // The next two loops are over cells of the Lhs (stacked vertically),
4841 // and over cells of the Rhs (stacked horizontally).
4842 for (int rc = 0; rc < Format::Lhs::kCells; rc++) {
4843 const OperandType* lhs_cell_ptr =
4844 lhs_ptr + (dc * Format::Lhs::kCells + rc) *
4845 Format::Lhs::Cell::kWidth * Format::kDepth;
4846 for (int cc = 0; cc < Format::Rhs::kCells; cc++) {
4847 const OperandType* rhs_cell_ptr =
4848 rhs_ptr + (dc * Format::Rhs::kCells + cc) *
4849 Format::Rhs::Cell::kWidth * Format::kDepth;
4850
4851 // Now we are inside one cell of the Lhs and inside one cell
4852 // of the Rhs, so the remaining inner loops are just
4853 // traditional three loops of matrix multiplication.
4854 for (int di = 0; di < Format::kDepth; di++) {
4855 for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) {
4856 for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) {
4857 const OperandType* lhs_coeff_ptr =
4858 lhs_cell_ptr +
4859 OffsetIntoCell<typename Format::Lhs::Cell>(ri, di);
4860 const OperandType* rhs_coeff_ptr =
4861 rhs_cell_ptr +
4862 OffsetIntoCell<typename Format::Rhs::Cell>(ci, di);
4863 AccumulatorType* accumulator_coeff_ptr =
4864 accum_ptr + (ri + rc * Format::Lhs::Cell::kWidth) +
4865 (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows;
4866 *accumulator_coeff_ptr += AccumulatorType(*lhs_coeff_ptr) *
4867 AccumulatorType(*rhs_coeff_ptr);
4868 }
4869 }
4870 }
4871 }
4872 }
4873 }
4874 }
4875 };
4876
4877 // END code copied from gemmlowp/internal/kernel_reference.h
4878
4879 template <typename DataType>
4880 class CacheLineAlignedBuffer {
4881 public:
CacheLineAlignedBuffer(std::size_t size)4882 CacheLineAlignedBuffer(std::size_t size) : size_(size) {
4883 data_ = nullptr;
4884 // Adds a few bytes of padding here, because the 64-bit 'A57' kernel
4885 // reads one iteration past the end the buffer, causing a crash on iOS.
4886 int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
4887 size_ * sizeof(DataType) + 16);
4888 (void)res;
4889 }
4890
~CacheLineAlignedBuffer()4891 ~CacheLineAlignedBuffer() { free(data_); }
4892
data() const4893 const DataType* data() const { return data_; }
data()4894 DataType* data() { return data_; }
4895
size() const4896 std::size_t size() const { return size_; }
4897
4898 private:
4899 const std::size_t size_;
4900 DataType* data_;
4901 };
4902
4903 template <typename DataType>
FillRandom(CacheLineAlignedBuffer<DataType> * buffer)4904 void FillRandom(CacheLineAlignedBuffer<DataType>* buffer) {
4905 static std::mt19937 generator(0);
4906 // 100 is smaller than any nonzero bound of the range of any data type.
4907 const DataType kMaxVal = DataType(100);
4908 const DataType kMinVal =
4909 std::is_signed<DataType>::value ? -kMaxVal : DataType(0);
4910 std::uniform_real_distribution<float> dist(kMinVal, kMaxVal);
4911 for (std::size_t i = 0; i < buffer->size(); i++) {
4912 buffer->data()[i] = DataType(dist(generator));
4913 }
4914 }
4915
4916 template <typename DataType>
FillZero(CacheLineAlignedBuffer<DataType> * buffer)4917 void FillZero(CacheLineAlignedBuffer<DataType>* buffer) {
4918 for (std::size_t i = 0; i < buffer->size(); i++) {
4919 buffer->data()[i] = DataType(0);
4920 }
4921 }
4922
4923 template <typename DataType>
Copy(CacheLineAlignedBuffer<DataType> * dst,const CacheLineAlignedBuffer<DataType> & src)4924 void Copy(CacheLineAlignedBuffer<DataType>* dst,
4925 const CacheLineAlignedBuffer<DataType>& src) {
4926 assert(dst->size() == src.size());
4927 memcpy(dst->data(), src.data(), src.size() * sizeof(DataType));
4928 }
4929
4930 template <typename DataType>
PrintMatrix(int rows,int cols,int rowstride,int colstride,const DataType * data)4931 void PrintMatrix(int rows, int cols, int rowstride, int colstride,
4932 const DataType* data) {
4933 for (int r = 0; r < rows; r++) {
4934 for (int c = 0; c < cols; c++) {
4935 std::cerr << double(data[r * rowstride + c * colstride]) << " ";
4936 }
4937 std::cerr << std::endl;
4938 }
4939 std::cerr << std::endl;
4940 }
4941
4942 template <typename DataType>
approx_equals(DataType a,DataType b)4943 bool approx_equals(DataType a, DataType b) {
4944 return a == b;
4945 }
4946
4947 template <>
approx_equals(float a,float b)4948 bool approx_equals(float a, float b) {
4949 if (!a && !b) {
4950 return true;
4951 }
4952 // 1e-1 is very coarse accuracy, we should switch to an overall L2 metric
4953 // and tighten the tolerance on that metric.
4954 return std::abs(a - b) < 1e-1f * std::min(std::abs(a), std::abs(b));
4955 }
4956
4957 template <typename Kernel>
test_kernel(int depth,const char * kernel_name)4958 void test_kernel(int depth, const char* kernel_name) {
4959 typedef typename Kernel::OperandType OperandType;
4960 typedef typename Kernel::AccumulatorType AccumulatorType;
4961 typedef typename Kernel::Format Format;
4962 static const int kLhsWidth = Format::Lhs::kWidth;
4963 static const int kRhsWidth = Format::Rhs::kWidth;
4964
4965 typedef ReferenceKernel<OperandType, AccumulatorType, Format> ReferenceKernel;
4966
4967 CacheLineAlignedBuffer<OperandType> lhs(kLhsWidth * depth);
4968 CacheLineAlignedBuffer<OperandType> rhs(kRhsWidth * depth);
4969 CacheLineAlignedBuffer<AccumulatorType> accum_initial(kLhsWidth * kRhsWidth);
4970 CacheLineAlignedBuffer<AccumulatorType> accum(kLhsWidth * kRhsWidth);
4971 CacheLineAlignedBuffer<AccumulatorType> accum_reference(kLhsWidth *
4972 kRhsWidth);
4973
4974 FillRandom(&lhs);
4975 FillRandom(&rhs);
4976 FillRandom(&accum_initial);
4977 Copy(&accum, accum_initial);
4978 Copy(&accum_reference, accum_initial);
4979
4980 ReferenceKernel::Run(lhs.data(), rhs.data(), accum_reference.data(), depth);
4981 Kernel::Run(lhs.data(), rhs.data(), accum.data(), depth);
4982
4983 for (int l = 0; l < kLhsWidth; l++) {
4984 for (int r = 0; r < kRhsWidth; r++) {
4985 const int index = l + kLhsWidth * r;
4986 if (!approx_equals(accum.data()[index], accum_reference.data()[index])) {
4987 std::cerr << "Arithmetic error in kernel:" << std::endl
4988 << " " << kernel_name << std::endl
4989 << "Wrong accumulator for depth=" << depth << ", "
4990 << "at l = " << l << ", r = " << r << std::endl;
4991 std::cerr << "reference value: " << accum_reference.data()[index]
4992 << std::endl;
4993 std::cerr << "actual value: " << accum.data()[index] << std::endl;
4994 if (depth <= 16) {
4995 std::cerr << "LHS matrix:" << std::endl;
4996 PrintMatrix(kLhsWidth, depth, 1, kLhsWidth, lhs.data());
4997 std::cerr << "RHS matrix:" << std::endl;
4998 PrintMatrix(depth, kRhsWidth, kRhsWidth, 1, rhs.data());
4999 std::cerr << "Initial Accumulator matrix:" << std::endl;
5000 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, accum_initial.data());
5001 std::cerr << "Reference Accumulator matrix:" << std::endl;
5002 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth,
5003 accum_reference.data());
5004 std::cerr << "Actual Accumulator matrix:" << std::endl;
5005 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, accum.data());
5006 }
5007 abort();
5008 }
5009 }
5010 }
5011 }
5012
5013 template <typename Kernel>
ops(int depth)5014 int ops(int depth) {
5015 // 2x the number of multiply-accumulate scalar ops.
5016 return 2 * Kernel::Format::Lhs::kWidth * Kernel::Format::Rhs::kWidth * depth;
5017 }
5018
5019 template <unsigned Modulus, typename Integer>
RoundDown(Integer i)5020 Integer RoundDown(Integer i) {
5021 return i - (i % Modulus);
5022 }
5023
CacheSizeInKB()5024 int CacheSizeInKB() {
5025 static const char* cache_size_k_env = getenv("CACHE_SIZE_KB");
5026 static const int cache_size_k =
5027 cache_size_k_env ? atoi(cache_size_k_env) : kDefaultCacheSizeK;
5028 return cache_size_k;
5029 }
5030
5031 template <typename Kernel>
BenchmarkDepthToFitInCache()5032 int BenchmarkDepthToFitInCache() {
5033 const int cache_size_bytes = 1024 * CacheSizeInKB();
5034
5035 // Subtract the typical size of a few cache lines, so
5036 // we don't need to worry too hard about e.g. some stack data.
5037 const int conservative_cache_size_bytes =
5038 cache_size_bytes - 2 * kCacheLineSize;
5039
5040 // We will subtract the memory occupied by accumulators.
5041 typedef typename Kernel::AccumulatorType AccumulatorType;
5042 const int kAccumulatorBytes = sizeof(AccumulatorType) *
5043 Kernel::Format::Lhs::kWidth *
5044 Kernel::Format::Rhs::kWidth;
5045
5046 // Compute the depth.
5047 typedef typename Kernel::OperandType OperandType;
5048 const int kBytesPerUnitOfDepth =
5049 sizeof(OperandType) *
5050 (Kernel::Format::Lhs::kWidth + Kernel::Format::Rhs::kWidth);
5051 const int unrounded_depth =
5052 (conservative_cache_size_bytes - kAccumulatorBytes) /
5053 kBytesPerUnitOfDepth;
5054
5055 // Cap depth, to avoid unfairly favoring narrower kernels
5056 const int kMaxDepth = 1024;
5057 const int clamped_unrounded_depth = std::min(kMaxDepth, unrounded_depth);
5058
5059 // Round depth down to a multiple of cache line size, which helps because
5060 // our kernels may crash if depth is not a multiple of the number of
5061 // depth level that they want to
5062 // handle at each loop iteration, and we don't want to require kernels
5063 // to be more complex. Currently all kernels process 1, 2 or 8 levels of
5064 // depth at a time. The main reason why that might increase in the future
5065 // is if registers get wider, but I don't suppose that register could
5066 // ever get wider than cache lines.
5067 return RoundDown<kCacheLineSize>(clamped_unrounded_depth);
5068 }
5069
current_time_in_seconds()5070 double current_time_in_seconds() {
5071 timespec t;
5072 clock_gettime(CLOCK_REALTIME, &t);
5073 return t.tv_sec + 1e-9 * t.tv_nsec;
5074 }
5075
5076 template <typename Kernel>
benchmark(int depth)5077 double benchmark(int depth) {
5078 // Minimum duration for this benchmark to run. If the workload finishes
5079 // sooner, we retry with double the number of iterations.
5080 static const double min_benchmark_time_in_seconds = 1.0;
5081
5082 typedef typename Kernel::OperandType OperandType;
5083 typedef typename Kernel::AccumulatorType AccumulatorType;
5084
5085 CacheLineAlignedBuffer<OperandType> lhs(Kernel::Format::Lhs::kWidth * depth);
5086 CacheLineAlignedBuffer<OperandType> rhs(Kernel::Format::Rhs::kWidth * depth);
5087 CacheLineAlignedBuffer<AccumulatorType> accum(Kernel::Format::Lhs::kWidth *
5088 Kernel::Format::Rhs::kWidth);
5089
5090 for (std::uint64_t iters_at_a_time = 1;; iters_at_a_time *= 2) {
5091 const double t_start = current_time_in_seconds();
5092 for (std::uint64_t i = 0; i < iters_at_a_time; i++) {
5093 Kernel::Run(lhs.data(), rhs.data(), accum.data(), depth);
5094 }
5095 const double t_end = current_time_in_seconds();
5096 const double elapsed = t_end - t_start;
5097 if (elapsed > min_benchmark_time_in_seconds) {
5098 return iters_at_a_time * ops<Kernel>(depth) / elapsed;
5099 }
5100 }
5101 }
5102
5103 template <typename Kernel>
benchmark_and_print_results(const char * kernel_name)5104 void benchmark_and_print_results(const char* kernel_name) {
5105 if (getenv("BENCHMARK_KERNEL")) {
5106 if (strcmp(getenv("BENCHMARK_KERNEL"), kernel_name)) {
5107 return;
5108 }
5109 }
5110 const int kKernelDepth = Kernel::Format::kDepth;
5111 for (int depth = kKernelDepth; depth <= 1024; depth += kKernelDepth) {
5112 test_kernel<Kernel>(depth, kernel_name);
5113 }
5114
5115 if (getenv("BENCHMARK_ALL_DEPTHS")) {
5116 for (int depth = kKernelDepth;
5117 depth <= BenchmarkDepthToFitInCache<Kernel>(); depth *= 2) {
5118 std::cout << kernel_name << "," << depth << ","
5119 << benchmark<Kernel>(depth) * 1e-9f << std::endl;
5120 }
5121 } else {
5122 const int depth = BenchmarkDepthToFitInCache<Kernel>();
5123 std::cout << kernel_name << "," << benchmark<Kernel>(depth) * 1e-9f
5124 << std::endl;
5125 }
5126 }
5127
5128 #define BENCHMARK(Kernel) \
5129 do { \
5130 benchmark_and_print_results<Kernel>(#Kernel); \
5131 } while (false)
5132
main()5133 int main() {
5134 if (getenv("BENCHMARK_ALL_DEPTHS")) {
5135 std::cout << "kernel,depth,Gop/s" << std::endl;
5136 } else {
5137 std::cout << "kernel,Gop/s" << std::endl;
5138 }
5139
5140 #ifdef __arm__
5141 BENCHMARK(NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
5142 BENCHMARK(NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
5143 BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators);
5144 BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
5145 BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand);
5146 BENCHMARK(NEON_32bit_GEMM_Int32_WithScalar);
5147 BENCHMARK(NEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar);
5148 #ifdef __ARM_FEATURE_FMA
5149 BENCHMARK(NEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar);
5150 #endif
5151 BENCHMARK(NEON_32bit_GEMM_Float32_MLA_WithScalar);
5152 BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_intrinsics);
5153 BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_A53);
5154 BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_A53_depth2);
5155 BENCHMARK(NEON_32bit_GEMM_Float32_MLA_Rotating);
5156 #ifdef __ARM_FEATURE_FMA
5157 BENCHMARK(NEON_32bit_GEMM_Float32_FMA_Rotating);
5158 #endif
5159 #endif
5160
5161 #ifdef __aarch64__
5162 BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
5163 BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
5164 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators);
5165 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
5166 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57);
5167 #ifdef __ARM_FEATURE_DOTPROD
5168 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct);
5169 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1);
5170 #endif
5171 BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar);
5172 BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar);
5173 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar);
5174 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_intrinsics);
5175 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A57);
5176 #ifndef __APPLE__
5177 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53);
5178 #endif
5179 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1);
5180 #endif
5181
5182 #ifdef __mips
5183 BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics);
5184 BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly);
5185 BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2);
5186 BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics);
5187 BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly);
5188 BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2);
5189 #endif
5190
5191 return 0;
5192 }
5193