• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include <vector>

#include "arm_gemm.hpp"
#include "bfloat.hpp"
#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemv_batched.hpp"
#include "gemv_pretransposed.hpp"

#include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
#include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"

43 namespace arm_gemm {
44 
45 static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
46 {
47 #ifdef V8P6_BF
48 #ifdef __ARM_FEATURE_SVE
49 { // gemm_bf16_interleaved
50     GemmMethod::GEMM_INTERLEAVED,
51     "sve_interleaved_bf16fp32_mmla_8x3VL",
__anonf364b8c80102() 52     [](const GemmArgs &args) { return (args._Ksize>4); },
53     nullptr,
__anonf364b8c80202() 54     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(args); }
55 },
56 {
57     GemmMethod::GEMM_HYBRID,
58     "sve_hybrid_bf16fp32_dot_6x4VL",
59     nullptr,
__anonf364b8c80302() 60     [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)); },
__anonf364b8c80402() 61     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>(args); }
62 },
63 { // gemm_bf16_interleaved
64     GemmMethod::GEMM_INTERLEAVED,
65     "sve_interleaved_bf16fp32_dot_8x3VL",
__anonf364b8c80502() 66     [](const GemmArgs &args) { return (args._Ksize>2); },
67     nullptr,
__anonf364b8c80602() 68     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
69 },
70 # endif // SVE
71 { // gemm_bf16_interleaved
72     GemmMethod::GEMM_INTERLEAVED,
73     "a64_interleaved_bf16fp32_mmla_8x12",
__anonf364b8c80702() 74     [](const GemmArgs &args) { return (args._Ksize>4); },
75     nullptr,
__anonf364b8c80802() 76     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
77 },
78 {
79     GemmMethod::GEMM_HYBRID,
80     "a64_hybrid_bf16fp32_dot_6x16",
81     nullptr,
82     nullptr,
__anonf364b8c80902() 83     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>(args); }
84 },
85 { // gemm_bf16_interleaved
86     GemmMethod::GEMM_INTERLEAVED,
87     "a64_interleaved_bf16fp32_dot_8x12",
__anonf364b8c80a02() 88     [](const GemmArgs &args) { return (args._Ksize>2); },
89     nullptr,
__anonf364b8c80b02() 90     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
91 },
92 #endif // V8P6_BF
93 #ifdef __aarch64__
94 {
95     GemmMethod::GEMM_INTERLEAVED,
96     "a64_sgemm_8x12",
97     nullptr,
98     nullptr,
__anonf364b8c80c02() 99     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>(args); }
100 },
101 #elif defined(__arm__)
102 {
103     GemmMethod::GEMM_INTERLEAVED,
104     "sgemm_8x6",
105     nullptr,
106     nullptr,
107     [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, bfloat16, float>(args); }
108 },
109 #else
110 # error "Unknown Architecture"
111 #endif
112 {
113     GemmMethod::DEFAULT,
114     "",
115     nullptr,
116     nullptr,
117     nullptr
118 }
119 };
120 
121 template<>
gemm_implementation_list()122 const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, float>() {
123     return gemm_bf16_methods;
124 }
125 
126 /* Explicitly instantiate the external functions for these types. */
127 template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
128 template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
129 template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
130 
131 } // namespace arm_gemm
132