1 /*
2 * Copyright (c) 2017-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #ifdef __aarch64__
25
26 #include "arm_gemm.hpp"
27 #include "gemm_common.hpp"
28 #include "gemm_implementation.hpp"
29 #include "gemm_interleaved.hpp"
30 #include "gemm_interleaved_pretransposed_2d.hpp"
31 #include "gemm_hybrid.hpp"
32 #include "gemm_hybrid_indirect.hpp"
33
34 #include "kernels/a64_gemm_u16_8x12.hpp"
35 #include "kernels/a64_gemm_u8_4x4.hpp"
36 #include "kernels/a64_gemm_u8_8x12.hpp"
37 #include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
38 #include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
39 #include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
40 #include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"
41
42 #include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
43 #include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
44 #include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
45 #include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
46
47 namespace arm_gemm {
48
49 static const GemmImplementation<uint8_t, uint32_t> gemm_u8_methods[] = {
50 #ifdef __ARM_FEATURE_SVE
51 #ifdef MMLA_INT8
52 {
53 GemmMethod::GEMM_INTERLEAVED,
54 "sve_interleaved_u8u32_mmla_8x3VL",
__anonb2d5b2910102() 55 [](const GemmArgs &args) { return (args._Ksize>8); },
56 nullptr,
__anonb2d5b2910202() 57 [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint32_t>(args); }
58 },
59 #endif
60 {
61 GemmMethod::GEMM_HYBRID,
62 "smallK_hybrid_u8u32_dot_8x1VL",
__anonb2d5b2910302() 63 [](const GemmArgs &args) { return args._Ksize<=64 && !args._indirect_input; },
64 nullptr,
__anonb2d5b2910402() 65 [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint32_t>(args); }
66 },
67 {
68 GemmMethod::GEMM_HYBRID,
69 "sve_hybrid_u8u32_dot_6x4VL",
70 nullptr,
__anonb2d5b2910502() 71 [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
__anonb2d5b2910602() 72 [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint32_t>(args); }
73 },
74 {
75 GemmMethod::GEMM_INTERLEAVED,
76 "sve_interleaved_u8u32_dot_8x3VL",
__anonb2d5b2910702() 77 [](const GemmArgs &args) { return (args._Ksize>4); },
78 nullptr,
__anonb2d5b2910802() 79 [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint32_t>(args); }
80 },
81 #endif
82 #ifdef MMLA_INT8
83 {
84 GemmMethod::GEMM_INTERLEAVED,
85 "a64_interleaved_u8u32_mmla_8x12",
__anonb2d5b2910902() 86 [](const GemmArgs &args) { return (args._Ksize>8); },
87 nullptr,
__anonb2d5b2910a02() 88 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint32_t>(args); }
89 },
90 #endif
91 {
92 GemmMethod::GEMM_HYBRID,
93 "a64_smallK_hybrid_u8u32_dot_8x4",
__anonb2d5b2910b02() 94 [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
95 nullptr,
__anonb2d5b2910c02() 96 [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
97 },
98 {
99 GemmMethod::GEMM_HYBRID,
100 "a64_smallK_hybrid_u8u32_dot_6x4",
__anonb2d5b2910d02() 101 [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
102 nullptr,
__anonb2d5b2910e02() 103 [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
104 },
105 {
106 GemmMethod::GEMM_INTERLEAVED,
107 "a64_gemm_u16_8x12",
108 nullptr,
__anonb2d5b2910f02() 109 [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53; },
__anonb2d5b2911002() 110 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u16_8x12, uint8_t, uint32_t>(args); },
111 },
112 {
113 GemmMethod::GEMM_HYBRID,
114 "a64_hybrid_u8u32_dot_6x16",
__anonb2d5b2911102() 115 [](const GemmArgs &args) { return args._ci->has_dotprod(); },
__anonb2d5b2911202() 116 [](const GemmArgs &args) { return args._Nsize<=256 && args._Ksize>128; },
__anonb2d5b2911302() 117 [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint32_t>(args); }
118 },
119 {
120 GemmMethod::GEMM_INTERLEAVED,
121 "a64_gemm_u8_8x12",
__anonb2d5b2911402() 122 [](const GemmArgs &args) { return args._ci->has_dotprod(); },
123 nullptr,
__anonb2d5b2911502() 124 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u8_8x12, uint8_t, uint32_t>(args); }
125 },
126 {
127 GemmMethod::GEMM_INTERLEAVED,
128 "a64_gemm_u8_4x4",
129 nullptr,
130 nullptr,
__anonb2d5b2911602() 131 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u8_4x4, uint8_t, uint32_t>(args); }
132 },
133 {
134 GemmMethod::DEFAULT,
135 "",
136 nullptr,
137 nullptr,
138 nullptr
139 }
140 };
141
142 template<>
gemm_implementation_list()143 const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, uint32_t>() {
144 return gemm_u8_methods;
145 }
146
147 /* Explicitly instantiate the external functions for these types. */
148 template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
149 template KernelDescription get_gemm_method<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
150 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
151
152 } // namespace arm_gemm
153
154 #endif // __aarch64__
155