1 /*
2 * Copyright (c) 2017-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_gemm.hpp"
25 #include "bfloat.hpp"
26 #include "gemm_common.hpp"
27 #include "gemm_hybrid.hpp"
28 #include "gemm_hybrid_indirect.hpp"
29 #include "gemm_implementation.hpp"
30 #include "gemm_interleaved.hpp"
31 #include "gemv_batched.hpp"
32 #include "gemv_pretransposed.hpp"
33
34 #include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
35 #include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
36 #include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
37 #include "kernels/a64_sgemm_8x12.hpp"
38 #include "kernels/a32_sgemm_8x6.hpp"
39 #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
40 #include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp"
41 #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
42
43 namespace arm_gemm {
44
45 static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
46 {
47 #ifdef V8P6_BF
48 #ifdef __ARM_FEATURE_SVE
49 { // gemm_bf16_interleaved
50 GemmMethod::GEMM_INTERLEAVED,
51 "sve_interleaved_bf16fp32_mmla_8x3VL",
__anonf364b8c80102() 52 [](const GemmArgs &args) { return (args._Ksize>4); },
53 nullptr,
__anonf364b8c80202() 54 [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(args); }
55 },
56 {
57 GemmMethod::GEMM_HYBRID,
58 "sve_hybrid_bf16fp32_dot_6x4VL",
59 nullptr,
__anonf364b8c80302() 60 [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)); },
__anonf364b8c80402() 61 [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>(args); }
62 },
63 { // gemm_bf16_interleaved
64 GemmMethod::GEMM_INTERLEAVED,
65 "sve_interleaved_bf16fp32_dot_8x3VL",
__anonf364b8c80502() 66 [](const GemmArgs &args) { return (args._Ksize>2); },
67 nullptr,
__anonf364b8c80602() 68 [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
69 },
70 # endif // SVE
71 { // gemm_bf16_interleaved
72 GemmMethod::GEMM_INTERLEAVED,
73 "a64_interleaved_bf16fp32_mmla_8x12",
__anonf364b8c80702() 74 [](const GemmArgs &args) { return (args._Ksize>4); },
75 nullptr,
__anonf364b8c80802() 76 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
77 },
78 {
79 GemmMethod::GEMM_HYBRID,
80 "a64_hybrid_bf16fp32_dot_6x16",
81 nullptr,
82 nullptr,
__anonf364b8c80902() 83 [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>(args); }
84 },
85 { // gemm_bf16_interleaved
86 GemmMethod::GEMM_INTERLEAVED,
87 "a64_interleaved_bf16fp32_dot_8x12",
__anonf364b8c80a02() 88 [](const GemmArgs &args) { return (args._Ksize>2); },
89 nullptr,
__anonf364b8c80b02() 90 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
91 },
92 #endif // V8P6_BF
93 #ifdef __aarch64__
94 {
95 GemmMethod::GEMM_INTERLEAVED,
96 "a64_sgemm_8x12",
97 nullptr,
98 nullptr,
__anonf364b8c80c02() 99 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>(args); }
100 },
101 #elif defined(__arm__)
102 {
103 GemmMethod::GEMM_INTERLEAVED,
104 "sgemm_8x6",
105 nullptr,
106 nullptr,
107 [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, bfloat16, float>(args); }
108 },
109 #else
110 # error "Unknown Architecture"
111 #endif
112 {
113 GemmMethod::DEFAULT,
114 "",
115 nullptr,
116 nullptr,
117 nullptr
118 }
119 };
120
121 template<>
gemm_implementation_list()122 const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, float>() {
123 return gemm_bf16_methods;
124 }
125
126 /* Explicitly instantiate the external functions for these types. */
127 template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
128 template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
129 template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
130
131 } // namespace arm_gemm
132