/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"
#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"

#include "kernels/a64_gemm_s16_8x12.hpp"
#include "kernels/a64_gemm_s8_8x12.hpp"
#include "kernels/a64_gemm_s8_4x4.hpp"
#include "kernels/a64_hybrid_s8s32_dot_6x16.hpp"
#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp"
#include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp"
#include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp"

#include "kernels/sve_hybrid_s8s32_dot_6x4VL.hpp"
#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp"

namespace arm_gemm {

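/* Candidate implementations for int8 input / int32 output GEMM.
 *
 * As can be read off the initializers below, each entry supplies a
 * GemmMethod, a kernel name, an optional "is supported" predicate, an
 * optional "is recommended" heuristic, and a factory lambda that builds the
 * corresponding GEMM object for a given GemmArgs.  The selection logic lives
 * in gemm_implementation.hpp; broadly, it walks this table in order and
 * takes the first entry whose predicates accept the supplied arguments, with
 * the trailing DEFAULT entry acting as the end-of-list sentinel. */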
static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = {
#ifdef __ARM_FEATURE_SVE
#ifdef MMLA_INT8
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_s8s32_mmla_8x3VL",
    [](const GemmArgs &args) { return (args._Ksize>8); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, int32_t>(args); }
},
#endif
{
    GemmMethod::GEMM_HYBRID,
    "sve_smallK_hybrid_s8s32_dot_8x1VL",
    [](const GemmArgs &args) { return args._Ksize<=64 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_s8s32_dot_8x1VL, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_s8s32_dot_6x4VL",
    [](const GemmArgs &args) { return args._Ksize>=16; },
    [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_s8s32_dot_6x4VL, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_s8s32_dot_8x3VL",
    [](const GemmArgs &args) { return (args._Ksize>4); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, int32_t>(args); }
},
#endif // SVE
#ifdef MMLA_INT8
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_s8s32_mmla_8x12",
    [](const GemmArgs &args) { return (args._Ksize>8); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, int32_t>(args); }
},
#endif
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_s8s32_dot_8x4",
    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_s8s32_dot_6x4",
    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_s16_8x12",
    nullptr,
    [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53 && args._Ksize>4; },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_s16_8x12, int8_t, int32_t>(args); },
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_s8s32_dot_6x16",
    [](const GemmArgs &args) { return args._ci->has_dotprod(); },
    [](const GemmArgs &args) { return args._Nsize<=256 && args._Ksize>128; },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_s8s32_dot_6x16, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_s8_8x12",
    [](const GemmArgs &args) { return args._ci->has_dotprod(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_s8_8x12, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_s8_4x4",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_s8_4x4, int8_t, int32_t>(args); }
},
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};

template<>
const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int32_t>() {
    return gemm_s8_methods;
}

/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
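
/* Illustrative caller-side sketch (an assumption about the surrounding API,
 * not part of this file; GemmArgs construction is shown schematically):
 *
 *   GemmArgs args = ...;   // problem sizes, CPUInfo, thread count, etc.
 *   auto g = arm_gemm::gemm<int8_t, int32_t, Nothing>(args, Nothing());
 *
 * The returned UniqueGemmCommon<int8_t, int32_t> should wrap whichever entry
 * of gemm_s8_methods was selected for 'args'; get_gemm_method() and
 * get_compatible_kernels() expose the same selection for inspection. */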

} // namespace arm_gemm

#endif // __aarch64__