• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2019-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #ifdef __aarch64__
25 
26 #include "arm_gemm.hpp"
27 
28 #include "kernels/a64_gemm_u16_8x12.hpp"
29 #include "kernels/a64_gemm_u8_4x4.hpp"
30 #include "kernels/a64_gemm_u8_8x12.hpp"
31 #include "kernels/a64_hybrid_u8qa_dot_4x16.hpp"
32 #include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
33 #include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
34 #include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
35 #include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"
36 
37 #include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
38 #include "kernels/sve_hybrid_u8qa_dot_4x4VL.hpp"
39 #include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
40 #include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
41 #include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
42 
43 #include "gemm_hybrid_indirect.hpp"
44 #include "gemm_hybrid_quantized.hpp"
45 #include "gemm_hybrid_quantized_inline.hpp"
46 #include "gemm_interleaved.hpp"
47 #include "quantize_wrapper.hpp"
48 
49 namespace arm_gemm {
50 
51 static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] =
52 {
53 #ifdef __ARM_FEATURE_SVE
54 #ifdef MMLA_INT8
55 {
56     GemmMethod::GEMM_INTERLEAVED,
57     "sve_interleaved_u8u32_mmla_8x3VL",
__anon949dab220102() 58     [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
59     nullptr,
__anon949dab220202() 60     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>(args, qp); }
61 },
62 #endif
63 {
64     GemmMethod::GEMM_HYBRID_QUANTIZED,
65     "sve_smallK_hybrid_u8u32_dot_8x1VL",
__anon949dab220302() 66     [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._indirect_input; },
67     nullptr,
__anon949dab220402() 68     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint8_t>(args, qp); }
69 },
70 #ifdef SVE2 // Requantizing kernels include some SVE2 only instructions (SQRDMULH, SRSHL)
71 {
72     GemmMethod::GEMM_HYBRID,
73     "sve_hybrid_u8qa_dot_4x4VL",
__anon949dab220502() 74     [](const GemmArgs &args, const Requantize32 &qp) { return quant_hybrid_asymmetric(qp); },
75     nullptr,
__anon949dab220602() 76     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>(args, qp); }
77 },
78 #endif
79 {
80     GemmMethod::GEMM_HYBRID,
81     "sve_hybrid_u8u32_dot_6x4VL",
82     nullptr,
83     nullptr,
__anon949dab220702() 84     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>(args, qp); }
85 },
86 {
87     GemmMethod::GEMM_INTERLEAVED,
88     "sve_interleaved_u8u32_dot_8x3VL",
__anon949dab220802() 89     [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>4); },
90     nullptr,
__anon949dab220902() 91     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>(args, qp); }
92 },
93 #endif
94 #ifdef MMLA_INT8
95 {
96     GemmMethod::GEMM_INTERLEAVED,
97     "a64_interleaved_u8u32_mmla_8x12",
__anon949dab220a02() 98     [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
99     nullptr,
__anon949dab220b02() 100     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>(args, qp); }
101 },
102 #endif
103 {
104     GemmMethod::GEMM_HYBRID_QUANTIZED,
105     "a64_smallK_hybrid_u8u32_dot_8x4",
__anon949dab220c02() 106     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
107     nullptr,
__anon949dab220d02() 108     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint8_t>(args, qp); }
109 },
110 {
111     GemmMethod::GEMM_HYBRID_QUANTIZED,
112     "a64_smallK_hybrid_u8u32_dot_6x4",
__anon949dab220e02() 113     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
114     nullptr,
__anon949dab220f02() 115     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint8_t>(args, qp); }
116 },
117 {
118     GemmMethod::GEMM_INTERLEAVED,
119     "a64_gemm_u16_8x12",
120     nullptr,
__anon949dab221002() 121     [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() == CPUModel::A53; },
__anon949dab221102() 122     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u16_8x12, uint8_t, uint8_t>(args, qp); },
123 },
124 {
125     GemmMethod::GEMM_HYBRID,
126     "a64_hybrid_u8qa_dot_4x16",
__anon949dab221202() 127     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_asymmetric(qp); },
__anon949dab221302() 128     [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
__anon949dab221402() 129     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>(args, qp); }
130 },
131 {
132     GemmMethod::GEMM_HYBRID,
133     "a64_hybrid_u8u32_dot_6x16",
__anon949dab221502() 134     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
__anon949dab221602() 135     [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
__anon949dab221702() 136     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>(args, qp); }
137 },
138 {
139     GemmMethod::GEMM_INTERLEAVED,
140     "a64_gemm_u8_8x12",
__anon949dab221802() 141     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
142     nullptr,
__anon949dab221902() 143     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>(args, qp); }
144 },
145 {
146     GemmMethod::GEMM_INTERLEAVED,
147     "a64_gemm_u8_4x4",
148     nullptr,
149     nullptr,
__anon949dab221a02() 150     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>(args, qp); }
151 },
152 {
153     GemmMethod::QUANTIZE_WRAPPER,
154     "quantized_wrapper",
__anon949dab221b02() 155     [](const GemmArgs &args, const Requantize32 &) { return !args._indirect_input; },
156     nullptr,
__anon949dab221c02() 157     [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
158 },
159 {
160     GemmMethod::DEFAULT,
161     "",
162     nullptr,
163     nullptr,
164     nullptr
165 }
166 };
167 
168 template<>
gemm_implementation_list()169 const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_list<uint8_t, uint8_t, Requantize32>() {
170     return gemm_quint8_methods;
171 }
172 
173 template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
174 template KernelDescription get_gemm_method<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
175 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
176 
177 } // namespace arm_gemm
178 
179 #endif // __aarch64__
180