/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"
#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"

#include "kernels/a64_gemm_s16_8x12.hpp"
#include "kernels/a64_gemm_s8_8x12.hpp"
#include "kernels/a64_gemm_s8_4x4.hpp"
#include "kernels/a64_hybrid_s8s32_dot_6x16.hpp"
#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp"
#include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp"
#include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp"

#include "kernels/sve_hybrid_s8s32_dot_6x4VL.hpp"
#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp"

namespace arm_gemm {

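/* Candidate int8 GEMM implementations, listed in priority order.  Each entry gives the
 * GEMM method, the kernel name, an optional "is supported" predicate, an optional
 * "is recommended" heuristic, and a function that instantiates the implementation;
 * nullptr for a predicate means "always". */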
static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = {
#ifdef __ARM_FEATURE_SVE
#ifdef MMLA_INT8
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_s8s32_mmla_8x3VL",
    [](const GemmArgs &args) { return (args._Ksize>8); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, int32_t>(args); }
},
#endif
{
    GemmMethod::GEMM_HYBRID,
    "sve_smallK_hybrid_s8s32_dot_8x1VL",
    [](const GemmArgs &args) { return args._Ksize<=64 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_s8s32_dot_8x1VL, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_s8s32_dot_6x4VL",
    [](const GemmArgs &args) { return args._Ksize>=16; },
    [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_s8s32_dot_6x4VL, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_s8s32_dot_8x3VL",
    [](const GemmArgs &args) { return (args._Ksize>4); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, int32_t>(args); }
},
#endif // SVE
#ifdef MMLA_INT8
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_s8s32_mmla_8x12",
    [](const GemmArgs &args) { return (args._Ksize>8); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, int32_t>(args); }
},
#endif
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_s8s32_dot_8x4",
    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_s8s32_dot_6x4",
    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_s16_8x12",
    nullptr,
    [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53 && args._Ksize>4; },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_s16_8x12, int8_t, int32_t>(args); },
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_s8s32_dot_6x16",
    [](const GemmArgs &args) { return args._ci->has_dotprod(); },
    [](const GemmArgs &args) { return args._Nsize<=256 && args._Ksize>128; },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_s8s32_dot_6x16, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_s8_8x12",
    [](const GemmArgs &args) { return args._ci->has_dotprod(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_s8_8x12, int8_t, int32_t>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_s8_4x4",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_s8_4x4, int8_t, int32_t>(args); }
},
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};

template<>
const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int32_t>() {
    return gemm_s8_methods;
}

/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);

} // namespace arm_gemm

#endif // __aarch64__