1 /*
2 * Copyright (c) 2019-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/core/CL/gemm/CLGEMMHelpers.h"
33
34 #include <map>
35 #include <utility>
36
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
41 using namespace arm_compute::misc::shape_calculator;
42
CLGEMMReshapedKernelConfigurationBifrost(GPUTarget gpu)43 CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifrost(GPUTarget gpu)
44 : ICLGEMMKernelConfiguration(gpu)
45 {
46 }
47
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
51
52 // Configurations for Mali-G76
53 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
54 {
55 { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
56 { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
57 { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 },
58 { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 },
59 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 },
60 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
61 };
62
63 // Configurations for Mali-G52
64 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
65 {
66 { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32 },
67 { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f16 },
68 { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
69 { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
70 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
71 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
72 };
73
74 // Configurations for Mali-G7x
75 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
76 {
77 { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
78 { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
79 { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
80 { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
81 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
82 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
83 };
84
85 switch(_target)
86 {
87 case GPUTarget::G76:
88 if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
89 {
90 return (this->*gemm_configs_G76[data_type])(m, n, k, b);
91 }
92 else
93 {
94 ARM_COMPUTE_ERROR("Not supported data type");
95 }
96 default:
97 if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
98 {
99 return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
100 }
101 else
102 {
103 ARM_COMPUTE_ERROR("Not supported data type");
104 }
105 }
106 }
107
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)108 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
109 {
110 ARM_COMPUTE_UNUSED(k);
111 ARM_COMPUTE_UNUSED(b);
112
113 if(n <= 4)
114 {
115 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
116 }
117 else
118 {
119 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true);
120 }
121 }
122
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)123 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
124 {
125 ARM_COMPUTE_UNUSED(k);
126 ARM_COMPUTE_UNUSED(b);
127
128 if(n <= 4)
129 {
130 return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
131 }
132 else
133 {
134 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
135 }
136 }
137
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
139 {
140 ARM_COMPUTE_UNUSED(k);
141 ARM_COMPUTE_UNUSED(b);
142
143 if(dot8_supported(CLKernelLibrary::get().get_device()))
144 {
145 if(n <= 4)
146 {
147 return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true);
148 }
149 else
150 {
151 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true);
152 }
153 }
154 else
155 {
156 if(n <= 4)
157 {
158 return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true);
159 }
160 else
161 {
162 return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true);
163 }
164 }
165 }
166
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)167 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
168 {
169 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
170 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
171 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
172 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
173
174 GEMMLHSMatrixInfo lhs_info_buf;
175 GEMMRHSMatrixInfo rhs_info_buf;
176 GEMMLHSMatrixInfo lhs_info_img;
177 GEMMRHSMatrixInfo rhs_info_img;
178
179 if(workload <= 274.4000f)
180 {
181 if(r_nk <= 0.7461f)
182 {
183 if(r_mn <= 21.1667f)
184 {
185 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false);
186 }
187 else
188 {
189 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
190 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
191
192 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
193 std::make_pair(lhs_info_buf, rhs_info_buf),
194 n, k, b, DataType::F32);
195 }
196 }
197 else
198 {
199 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
200 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
201
202 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
203 std::make_pair(lhs_info_buf, rhs_info_buf),
204 n, k, b, DataType::F32);
205 }
206 }
207 else
208 {
209 if(r_mk <= 17.3926f)
210 {
211 if(workload <= 542.4000f)
212 {
213 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
214 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
215
216 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
217 std::make_pair(lhs_info_buf, rhs_info_buf),
218 n, k, b, DataType::F32);
219 }
220 else
221 {
222 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
223 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
224
225 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
226 std::make_pair(lhs_info_buf, rhs_info_buf),
227 n, k, b, DataType::F32);
228 }
229 }
230 else
231 {
232 if(r_nk <= 0.5463f)
233 {
234 if(workload <= 11767.6001f)
235 {
236 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
237 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
238
239 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
240 std::make_pair(lhs_info_buf, rhs_info_buf),
241 n, k, b, DataType::F32);
242 }
243 else
244 {
245 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
246 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
247
248 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
249 std::make_pair(lhs_info_buf, rhs_info_buf),
250 n, k, b, DataType::F32);
251 }
252 }
253 else
254 {
255 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
256 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
257
258 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
259 std::make_pair(lhs_info_buf, rhs_info_buf),
260 n, k, b, DataType::F32);
261 }
262 }
263 }
264 }
265
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)266 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
267 {
268 ARM_COMPUTE_UNUSED(k);
269
270 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
271
272 if(workload <= 323.4000f)
273 {
274 return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
275 }
276 else
277 {
278 return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
279 }
280 }
281
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)282 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
283 {
284 ARM_COMPUTE_UNUSED(k);
285 ARM_COMPUTE_UNUSED(b);
286
287 GEMMLHSMatrixInfo lhs_info_buf;
288 GEMMRHSMatrixInfo rhs_info_buf;
289 GEMMLHSMatrixInfo lhs_info_img;
290 GEMMRHSMatrixInfo rhs_info_img;
291
292 // Get lhs_info/rhs_info in case of OpenCL buffer
293 if(n <= 4)
294 {
295 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
296 }
297 else
298 {
299 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
300 }
301
302 // Get lhs_info/rhs_info in case of OpenCL image
303 // Condition on the GPU workload
304 if((m / 4) * (n / 4) >= 2560)
305 {
306 // Big workload
307 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
308 }
309 else
310 {
311 // Small workload
312 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
313 }
314
315 const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
316 const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
317 const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
318
319 // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
320 const bool use_cl_image2d = (n <= 4) ? false : true;
321
322 if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
323 {
324 return std::make_pair(lhs_info_img, rhs_info_img);
325 }
326 else
327 {
328 return std::make_pair(lhs_info_buf, rhs_info_buf);
329 }
330 }
331
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)332 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
333 {
334 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
335 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
336
337 if(workload <= 1595.2000f)
338 {
339 if(r_mk <= 2.1044f)
340 {
341 if(workload <= 870.4000f)
342 {
343 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false);
344 }
345 else
346 {
347 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
348 }
349 }
350 else
351 {
352 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
353 }
354 }
355 else
356 {
357 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false, false);
358 }
359 }
360
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)361 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
362 {
363 ARM_COMPUTE_UNUSED(k);
364 ARM_COMPUTE_UNUSED(b);
365
366 if(n <= 4)
367 {
368 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
369 }
370 else
371 {
372 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true);
373 }
374 }
375 } // namespace cl_gemm
376 } // namespace arm_compute
377