• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifdef __aarch64__
25 
26 #include "arm_gemm.hpp"
27 #include "gemm_common.hpp"
28 #include "gemm_implementation.hpp"
29 #include "gemm_interleaved.hpp"
30 #include "gemm_interleaved_pretransposed_2d.hpp"
31 #include "gemm_hybrid.hpp"
32 #include "gemm_hybrid_indirect.hpp"
33 
34 #include "kernels/a64_gemm_u16_8x12.hpp"
35 #include "kernels/a64_gemm_u8_4x4.hpp"
36 #include "kernels/a64_gemm_u8_8x12.hpp"
37 #include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
38 #include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
39 #include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
40 #include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"
41 
42 #include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
43 #include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
44 #include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
45 #include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
46 
47 namespace arm_gemm {
48 
/*
 * Candidate table of uint8 -> uint32 GEMM implementations for AArch64.
 *
 * Each aggregate row holds: a GemmMethod tag, the kernel's name string,
 * and three lambdas (any of which may be nullptr):
 *   - first predicate: gating condition observed to test capabilities /
 *     problem shape (nullptr => no restriction),
 *   - second predicate: heuristic condition observed to test sizes / CPU
 *     model (nullptr => no hint),
 *   - factory: allocates the strategy object wrapping the named kernel.
 * NOTE(review): the roles "is_supported" / "is_recommended" and the
 * first-eligible-entry-wins, top-down selection order are implied by the
 * table layout and the DEFAULT sentinel — confirm against
 * gemm_implementation.hpp.
 */
static const GemmImplementation<uint8_t, uint32_t> gemm_u8_methods[] = {
#ifdef __ARM_FEATURE_SVE
#ifdef MMLA_INT8
/* SVE 8-bit integer matrix-multiply (MMLA) interleaved kernel; gated on K > 8. */
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_u8u32_mmla_8x3VL",
    [](const GemmArgs &args) { return (args._Ksize>8); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint32_t>(args); }
},
#endif
/* Small-K SVE hybrid kernel: only for K <= 64 and direct (non-indirect) input. */
{
    GemmMethod::GEMM_HYBRID,
    "smallK_hybrid_u8u32_dot_8x1VL",
    [](const GemmArgs &args) { return args._Ksize<=64 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint32_t>(args); }
},
/* General SVE hybrid kernel; second predicate favours it for small N and K,
 * or multi-matrix cases leaving fewer than 8 rows per thread. */
{
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_u8u32_dot_6x4VL",
    nullptr,
    [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint32_t>(args); }
},
/* SVE interleaved dot-product kernel; gated on K > 4. */
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_u8u32_dot_8x3VL",
    [](const GemmArgs &args) { return (args._Ksize>4); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint32_t>(args); }
},
#endif
#ifdef MMLA_INT8
/* Non-SVE (A64) MMLA interleaved kernel; gated on K > 8. */
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_u8u32_mmla_8x12",
    [](const GemmArgs &args) { return (args._Ksize>8); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint32_t>(args); }
},
#endif
/* Small-K A64 hybrid kernels: require the dot-product extension, N divisible
 * by 4, direct input, and K in range (K <= 32 for 8x4; 32 < K <= 64 for 6x4). */
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_u8u32_dot_8x4",
    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_u8u32_dot_6x4",
    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
},
/* 16-bit widening interleaved kernel; recommended only on Cortex-A53. */
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_u16_8x12",
    nullptr,
    [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53; },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u16_8x12, uint8_t, uint32_t>(args); },
},
/* A64 dot-product hybrid kernel; recommended for modest N (<=256) with deep K (>128). */
{
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_u8u32_dot_6x16",
    [](const GemmArgs &args) { return args._ci->has_dotprod(); },
    [](const GemmArgs &args) { return args._Nsize<=256 && args._Ksize>128; },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint32_t>(args); }
},
/* A64 dot-product interleaved kernel; requires the dot-product extension. */
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_u8_8x12",
    [](const GemmArgs &args) { return args._ci->has_dotprod(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u8_8x12, uint8_t, uint32_t>(args); }
},
/* Unconditional baseline NEON kernel — always usable. */
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_u8_4x4",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u8_4x4, uint8_t, uint32_t>(args); }
},
/* End-of-list sentinel: DEFAULT method with empty name and null lambdas. */
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};
141 
142 template<>
gemm_implementation_list()143 const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, uint32_t>() {
144     return gemm_u8_methods;
145 }
146 
/* Explicitly instantiate the external functions for these types.
 * This forces the definitions to be emitted in this translation unit, so
 * other translation units can link against gemm/get_gemm_method/
 * get_compatible_kernels for <uint8_t, uint32_t> without needing the
 * template bodies. */
template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
151 
152 } // namespace arm_gemm
153 
154 #endif // __aarch64__
155