1 /*
2 * Copyright (c) 2019-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "src/core/CL/gemm/CLGEMMHelpers.h"
30
31 #include <map>
32 #include <utility>
33
34 namespace arm_compute
35 {
36 namespace cl_gemm
37 {
CLGEMMNativeKernelConfigurationBifrost(GPUTarget gpu)38 CLGEMMNativeKernelConfigurationBifrost::CLGEMMNativeKernelConfigurationBifrost(GPUTarget gpu)
39 : ICLGEMMKernelConfiguration(gpu)
40 {
41 }
42
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)43 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
44 {
45 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMNativeKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
46 unsigned int b);
47
48 // Configurations for Mali-G71
49 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G71 =
50 {
51 { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32 },
52 { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
53 { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
54 { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
55 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }
56 };
57
58 // Configurations for Mali-G76
59 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
60 {
61 { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32 },
62 { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
63 { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
64 { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
65 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }
66 };
67
68 // Default configurations
69 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_default =
70 {
71 { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_default_f32 },
72 { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
73 { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
74 { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
75 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }
76 };
77
78 switch(_target)
79 {
80 case GPUTarget::G71:
81 if(gemm_configs_G71.find(data_type) != gemm_configs_G71.end())
82 {
83 return (this->*gemm_configs_G71[data_type])(m, n, k, b);
84 }
85 else
86 {
87 ARM_COMPUTE_ERROR("Not supported data type");
88 }
89 case GPUTarget::G76:
90 if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
91 {
92 return (this->*gemm_configs_G76[data_type])(m, n, k, b);
93 }
94 else
95 {
96 ARM_COMPUTE_ERROR("Not supported data type");
97 }
98 default:
99 if(gemm_configs_default.find(data_type) != gemm_configs_default.end())
100 {
101 return (this->*gemm_configs_default[data_type])(m, n, k, b);
102 }
103 else
104 {
105 ARM_COMPUTE_ERROR("Not supported data type");
106 }
107 }
108 }
109
configure_G71_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)110 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
111 {
112 ARM_COMPUTE_UNUSED(k);
113 ARM_COMPUTE_UNUSED(b);
114
115 if(m == 1)
116 {
117 if(n < 2048)
118 {
119 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
120 }
121 else if(n >= 2048 && n < 8192)
122 {
123 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false);
124 }
125 else
126 {
127 return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false);
128 }
129 }
130 else
131 {
132 return configure_lhs_rhs_info(m, n, 5, 4, 2, 1, 1, false, false, false, false);
133 }
134 }
135
configure_G71_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)136 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
137 {
138 ARM_COMPUTE_UNUSED(k);
139 ARM_COMPUTE_UNUSED(b);
140
141 if(dot8_supported(CLKernelLibrary::get().get_device()))
142 {
143 if(m == 1)
144 {
145 if(n < 2048)
146 {
147 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
148 }
149 else if(n >= 2048 && n < 16384)
150 {
151 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
152 }
153 else
154 {
155 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
156 }
157 }
158 else
159 {
160 if(m < 64)
161 {
162 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
163 }
164 else
165 {
166 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
167 }
168 }
169 }
170 else
171 {
172 if(m == 1)
173 {
174 if(n < 8192)
175 {
176 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
177 }
178 else
179 {
180 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
181 }
182 }
183 else
184 {
185 return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1, false, false, false, false);
186 }
187 }
188 }
189
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)190 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
191 {
192 ARM_COMPUTE_UNUSED(k);
193 ARM_COMPUTE_UNUSED(b);
194
195 if(m == 1)
196 {
197 if(n > 4196)
198 {
199 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 1, false, false, false, false);
200 }
201 else
202 {
203 if(k < 2048)
204 {
205 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 1, false, false, false, false);
206 }
207 else if(k >= 2048 && k < 16384)
208 {
209 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
210 }
211 else
212 {
213 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 1, false, false, false, false);
214 }
215 }
216 }
217 else
218 {
219 return configure_lhs_rhs_info(m, n, 2, 8, 2, 1, 1, false, false, false, false);
220 }
221 }
222
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)223 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
224 {
225 ARM_COMPUTE_UNUSED(k);
226 ARM_COMPUTE_UNUSED(b);
227
228 if(m == 1)
229 {
230 if(n < 2048)
231 {
232 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
233 }
234 else if(n >= 2048 && n < 16384)
235 {
236 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
237 }
238 else
239 {
240 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
241 }
242 }
243 else
244 {
245 if(m < 64)
246 {
247 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
248 }
249 else
250 {
251 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
252 }
253 }
254 }
255
configure_default_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)256 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
257 {
258 ARM_COMPUTE_UNUSED(k);
259 ARM_COMPUTE_UNUSED(b);
260
261 return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 1, false, false, false, false);
262 }
263
configure_default_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)264 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
265 {
266 ARM_COMPUTE_UNUSED(k);
267 ARM_COMPUTE_UNUSED(b);
268
269 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
270 }
271 } // namespace cl_gemm
272 } // namespace arm_compute