1 /*
2 * Copyright (c) 2019-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/core/CL/gemm/CLGEMMHelpers.h"
33
34 #include <map>
35 #include <utility>
36
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
41 using namespace arm_compute::misc::shape_calculator;
42
CLGEMMReshapedOnlyRHSKernelConfigurationBifrost(GPUTarget gpu)43 CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::CLGEMMReshapedOnlyRHSKernelConfigurationBifrost(GPUTarget gpu)
44 : ICLGEMMKernelConfiguration(gpu)
45 {
46 }
47
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
51 unsigned int b);
52
53 // Configurations for Mali-G51
54 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G51 =
55 {
56 { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32 },
57 { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16 },
58 { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 },
59 { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 },
60 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 },
61 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }
62 };
63
64 // Configurations for Mali-G52
65 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
66 {
67 { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32 },
68 { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f16 },
69 { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
70 { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
71 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
72 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
73 };
74
75 // Configurations for Mali-G76
76 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
77 {
78 { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32 },
79 { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16 },
80 { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 },
81 { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 },
82 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 },
83 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }
84 };
85
86 // Configurations for Mali-G7x
87 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
88 {
89 { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32 },
90 { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 },
91 { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
92 { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
93 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
94 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
95 };
96
97 switch(_target)
98 {
99 case GPUTarget::G76:
100 if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
101 {
102 return (this->*gemm_configs_G76[data_type])(m, n, k, b);
103 }
104 else
105 {
106 ARM_COMPUTE_ERROR("Not supported data type");
107 }
108 case GPUTarget::G52:
109 if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end())
110 {
111 return (this->*gemm_configs_G52[data_type])(m, n, k, b);
112 }
113 else
114 {
115 ARM_COMPUTE_ERROR("Not supported data type");
116 }
117 case GPUTarget::G51:
118 if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
119 {
120 return (this->*gemm_configs_G51[data_type])(m, n, k, b);
121 }
122 else
123 {
124 ARM_COMPUTE_ERROR("Not supported data type");
125 }
126 default:
127 if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
128 {
129 return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
130 }
131 else
132 {
133 ARM_COMPUTE_ERROR("Not supported data type");
134 }
135 }
136 }
137
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
139 {
140 ARM_COMPUTE_UNUSED(k);
141 ARM_COMPUTE_UNUSED(b);
142
143 if(m == 1)
144 {
145 if(n <= 2548)
146 {
147 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
148 }
149 else
150 {
151 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
152 }
153 }
154 else
155 {
156 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
157 }
158 }
159
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)160 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
161 {
162 ARM_COMPUTE_UNUSED(k);
163 ARM_COMPUTE_UNUSED(b);
164
165 GEMMLHSMatrixInfo lhs_info_buf;
166 GEMMRHSMatrixInfo rhs_info_buf;
167 GEMMLHSMatrixInfo lhs_info_img;
168 GEMMRHSMatrixInfo rhs_info_img;
169
170 const bool is_workload_big = ((m * n * b) / 16) >= 2048;
171
172 if(m == 1)
173 {
174 if(n >= 8192)
175 {
176 const unsigned int h0 = std::max(n / 4, 1U);
177 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
178 }
179 else
180 {
181 const unsigned int h0 = std::max(n / 2, 1U);
182 if(n <= 204)
183 {
184 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
185 }
186 else
187 {
188 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
189 }
190 }
191 }
192 else
193 {
194 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
195 if(is_workload_big)
196 {
197 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
198 }
199 else
200 {
201 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
202 }
203 }
204
205 // Get lhs_info/rhs_info in case of OpenCL image
206 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
207 if(is_workload_big)
208 {
209 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
210 }
211 else
212 {
213 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
214 }
215
216 const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
217 const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
218 const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
219
220 // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
221 const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
222
223 if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
224 {
225 return std::make_pair(lhs_info_img, rhs_info_img);
226 }
227 else
228 {
229 return std::make_pair(lhs_info_buf, rhs_info_buf);
230 }
231 }
232
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)233 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
234 {
235 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
236 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
237
238 GEMMLHSMatrixInfo lhs_info_buf;
239 GEMMRHSMatrixInfo rhs_info_buf;
240 GEMMLHSMatrixInfo lhs_info_img;
241 GEMMRHSMatrixInfo rhs_info_img;
242
243 if(m == 1)
244 {
245 if(r_nk <= 0.4664f)
246 {
247 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
248 }
249 else
250 {
251 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
252 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
253
254 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
255 std::make_pair(lhs_info_buf, rhs_info_buf),
256 n, k, b, DataType::F32);
257 }
258 }
259 else
260 {
261 if(workload <= 274.4000f)
262 {
263 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
264 }
265 else
266 {
267 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
268 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
269
270 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
271 std::make_pair(lhs_info_buf, rhs_info_buf),
272 n, k, b, DataType::F32);
273 }
274 }
275 }
276
configure_G51_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)277 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
278 {
279 ARM_COMPUTE_UNUSED(k);
280 ARM_COMPUTE_UNUSED(b);
281
282 if(m == 1)
283 {
284 const unsigned int n0 = n < 1280 ? 2 : 4;
285 const unsigned int h0 = std::max(n / n0, 1U);
286 return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
287 }
288 else
289 {
290 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
291 }
292 }
293
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)294 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
295 {
296 ARM_COMPUTE_UNUSED(k);
297 ARM_COMPUTE_UNUSED(b);
298
299 if(m == 1)
300 {
301 if(n > 2048)
302 {
303 const unsigned int h0 = std::max(n / 4, 1U);
304 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
305 }
306 else
307 {
308 const unsigned int h0 = std::max(n / 2, 1U);
309 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
310 }
311 }
312 else
313 {
314 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
315 }
316 }
317
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)318 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
319 {
320 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
321 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
322 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
323 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
324
325 GEMMLHSMatrixInfo lhs_info_buf;
326 GEMMRHSMatrixInfo rhs_info_buf;
327 GEMMLHSMatrixInfo lhs_info_img;
328 GEMMRHSMatrixInfo rhs_info_img;
329
330 if(m == 1)
331 {
332 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
333
334 if(r_mk <= 0.0026f)
335 {
336 if(r_nk <= 0.4664f)
337 {
338 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
339 }
340 else
341 {
342 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
343 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
344 std::make_pair(lhs_info_buf, rhs_info_buf),
345 n, k, b, DataType::F16);
346 }
347 }
348 else
349 {
350 if(r_mk <= 0.0148f)
351 {
352 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
353 }
354 else
355 {
356 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
357 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
358 std::make_pair(lhs_info_buf, rhs_info_buf),
359 n, k, b, DataType::F16);
360 }
361 }
362 }
363 else
364 {
365 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
366
367 if(workload <= 362.6000f)
368 {
369 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
370 }
371 else
372 {
373 if(r_mn <= 22.6067f)
374 {
375 if(workload <= 708.8000f)
376 {
377 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
378 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
379 std::make_pair(lhs_info_buf, rhs_info_buf),
380 n, k, b, DataType::F16);
381 }
382 else
383 {
384 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
385 }
386 }
387 else
388 {
389 if(r_nk <= 0.0917f)
390 {
391 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
392 }
393 else
394 {
395 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
396 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
397 std::make_pair(lhs_info_buf, rhs_info_buf),
398 n, k, b, DataType::F16);
399 }
400 }
401 }
402 }
403 }
404
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)405 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
406 {
407 ARM_COMPUTE_UNUSED(k);
408
409 if(m == 1)
410 {
411 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
412 }
413 else
414 {
415 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
416 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
417
418 if(workload <= 7449.60f)
419 {
420 if(workload <= 691.60f)
421 {
422 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
423 }
424 else
425 {
426 if(workload <= 4155.20f)
427 {
428 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
429 }
430 else
431 {
432 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
433 }
434 }
435 }
436 else
437 {
438 if(workload <= 16300.80f)
439 {
440 if(r_mn <= 44.56f)
441 {
442 GEMMLHSMatrixInfo lhs_info_buf;
443 GEMMRHSMatrixInfo rhs_info_buf;
444 GEMMLHSMatrixInfo lhs_info_img;
445 GEMMRHSMatrixInfo rhs_info_img;
446
447 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
448 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
449
450 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
451 std::make_pair(lhs_info_buf, rhs_info_buf),
452 n, k, b, DataType::F16);
453 }
454 else
455 {
456 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
457 }
458 }
459 else
460 {
461 GEMMLHSMatrixInfo lhs_info_buf;
462 GEMMRHSMatrixInfo rhs_info_buf;
463 GEMMLHSMatrixInfo lhs_info_img;
464 GEMMRHSMatrixInfo rhs_info_img;
465
466 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
467 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
468
469 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
470 std::make_pair(lhs_info_buf, rhs_info_buf),
471 n, k, b, DataType::F16);
472 }
473 }
474 }
475 }
476
configure_G51_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)477 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
478 {
479 ARM_COMPUTE_UNUSED(k);
480 ARM_COMPUTE_UNUSED(b);
481
482 if(m == 1)
483 {
484 const unsigned int n0 = n < 1280 ? 2 : 4;
485 const unsigned int h0 = std::max(n / n0, 1U);
486 return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
487 }
488 else
489 {
490 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
491 }
492 }
493
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)494 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
495 {
496 ARM_COMPUTE_UNUSED(k);
497 ARM_COMPUTE_UNUSED(b);
498
499 if(dot8_supported(CLKernelLibrary::get().get_device()))
500 {
501 if(m == 1)
502 {
503 const unsigned int h0 = std::max(n / 2, 1U);
504 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
505 }
506 else
507 {
508 const unsigned int h0 = std::max(n / 4, 1U);
509 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
510 }
511 }
512 else
513 {
514 const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
515 if(m == 1)
516 {
517 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
518 }
519 else
520 {
521 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
522 }
523 }
524 }
525
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)526 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
527 {
528 ARM_COMPUTE_UNUSED(k);
529 ARM_COMPUTE_UNUSED(b);
530
531 if(m == 1)
532 {
533 const unsigned int h0 = std::max(n / 2, 1U);
534 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
535 }
536 else
537 {
538 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
539 }
540 }
541
configure_G51_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)542 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
543 {
544 ARM_COMPUTE_UNUSED(k);
545 ARM_COMPUTE_UNUSED(b);
546
547 if(m == 1)
548 {
549 const unsigned int h0 = std::max(n / 2, 1U);
550 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
551 }
552 else
553 {
554 const unsigned int h0 = std::max(n / 2, 1U);
555 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
556 }
557 }
558
559 } // namespace cl_gemm
560 } // namespace arm_compute
561