1 /*
2 * Copyright (c) 2019-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h"

#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"

#include <algorithm>
#include <cstdint>
#include <tuple>
#include <utility>
34
35 namespace arm_compute
36 {
37 namespace opencl
38 {
39 namespace kernels
40 {
41 namespace gemm
42 {
43 using namespace arm_compute::misc::shape_calculator;
44
ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu)45 ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu)
46 : IClGemmKernelConfig(gpu)
47 {
48 }
49
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)50 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
51 {
52 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
53 unsigned int b);
54
55 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32,
56 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
57 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
58
59 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32,
60 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
61 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
62
63 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
64 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
65 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
66
67 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
68 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
69 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
70
71 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
72 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
73 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
74
75 ConfigurationFunctionExecutorPtr func = nullptr;
76 switch(_target)
77 {
78 case GPUTarget::G76:
79 func = configs_G76.get_function(data_type);
80 break;
81 case GPUTarget::G51:
82 func = configs_G51.get_function(data_type);
83 break;
84 case GPUTarget::G52:
85 func = configs_G52.get_function(data_type);
86 break;
87 case GPUTarget::G31:
88 func = configs_G31.get_function(data_type);
89 break;
90 default:
91 func = configs_G7x.get_function(data_type);
92 break;
93 }
94
95 ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
96 return (this->*func)(m, n, k, b);
97 }
98
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)99 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
100 {
101 ARM_COMPUTE_UNUSED(k);
102 ARM_COMPUTE_UNUSED(b);
103
104 if(m == 1)
105 {
106 if(n <= 2548)
107 {
108 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
109 }
110 else
111 {
112 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
113 }
114 }
115 else
116 {
117 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
118 }
119 }
120
configure_G31_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)121 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
122 {
123 ARM_COMPUTE_UNUSED(k);
124 ARM_COMPUTE_UNUSED(b);
125
126 if(m == 1)
127 {
128 const unsigned int h0 = std::max(n / 2, 1U);
129 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
130 }
131 else
132 {
133 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
134 if(m >= 28)
135 {
136 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
137 }
138 else
139 {
140 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
141 }
142 }
143 }
144
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
146 {
147 ARM_COMPUTE_UNUSED(k);
148 ARM_COMPUTE_UNUSED(b);
149
150 GEMMLHSMatrixInfo lhs_info_buf;
151 GEMMRHSMatrixInfo rhs_info_buf;
152 GEMMLHSMatrixInfo lhs_info_img;
153 GEMMRHSMatrixInfo rhs_info_img;
154
155 const bool is_workload_big = ((m * n * b) / 16) >= 2048;
156
157 if(m == 1)
158 {
159 if(n >= 8192)
160 {
161 const unsigned int h0 = std::max(n / 4, 1U);
162 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
163 }
164 else
165 {
166 const unsigned int h0 = std::max(n / 2, 1U);
167 if(n <= 204)
168 {
169 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
170 }
171 else
172 {
173 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
174 }
175 }
176 }
177 else
178 {
179 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
180 if(is_workload_big)
181 {
182 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
183 }
184 else
185 {
186 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
187 }
188 }
189
190 // Get lhs_info/rhs_info in case of OpenCL image
191 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
192 if(is_workload_big)
193 {
194 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
195 }
196 else
197 {
198 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
199 }
200
201 const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
202 const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
203 const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
204
205 // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
206 const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
207
208 if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
209 {
210 return std::make_pair(lhs_info_img, rhs_info_img);
211 }
212 else
213 {
214 return std::make_pair(lhs_info_buf, rhs_info_buf);
215 }
216 }
217
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)218 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
219 {
220 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
221 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
222
223 GEMMLHSMatrixInfo lhs_info_buf;
224 GEMMRHSMatrixInfo rhs_info_buf;
225 GEMMLHSMatrixInfo lhs_info_img;
226 GEMMRHSMatrixInfo rhs_info_img;
227
228 if(m == 1)
229 {
230 if(r_nk <= 0.4664f)
231 {
232 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
233 }
234 else
235 {
236 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
237 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
238
239 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
240 std::make_pair(lhs_info_buf, rhs_info_buf),
241 n, k, b, DataType::F32);
242 }
243 }
244 else
245 {
246 if(workload <= 274.4000f)
247 {
248 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
249 }
250 else
251 {
252 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
253 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
254
255 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
256 std::make_pair(lhs_info_buf, rhs_info_buf),
257 n, k, b, DataType::F32);
258 }
259 }
260 }
261
configure_G51_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)262 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
263 {
264 ARM_COMPUTE_UNUSED(k);
265 ARM_COMPUTE_UNUSED(b);
266
267 if(m == 1)
268 {
269 const unsigned int n0 = n < 1280 ? 2 : 4;
270 const unsigned int h0 = std::max(n / n0, 1U);
271 return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
272 }
273 else
274 {
275 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
276 }
277 }
278
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)279 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
280 {
281 ARM_COMPUTE_UNUSED(k);
282 ARM_COMPUTE_UNUSED(b);
283
284 if(m == 1)
285 {
286 if(n > 2048)
287 {
288 const unsigned int h0 = std::max(n / 4, 1U);
289 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
290 }
291 else
292 {
293 const unsigned int h0 = std::max(n / 2, 1U);
294 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
295 }
296 }
297 else
298 {
299 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
300 }
301 }
302
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)303 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
304 {
305 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
306 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
307 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
308 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
309
310 GEMMLHSMatrixInfo lhs_info_buf;
311 GEMMRHSMatrixInfo rhs_info_buf;
312 GEMMLHSMatrixInfo lhs_info_img;
313 GEMMRHSMatrixInfo rhs_info_img;
314
315 if(m == 1)
316 {
317 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
318
319 if(r_mk <= 0.0026f)
320 {
321 if(r_nk <= 0.4664f)
322 {
323 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
324 }
325 else
326 {
327 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
328 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
329 std::make_pair(lhs_info_buf, rhs_info_buf),
330 n, k, b, DataType::F16);
331 }
332 }
333 else
334 {
335 if(r_mk <= 0.0148f)
336 {
337 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
338 }
339 else
340 {
341 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
342 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
343 std::make_pair(lhs_info_buf, rhs_info_buf),
344 n, k, b, DataType::F16);
345 }
346 }
347 }
348 else
349 {
350 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
351
352 if(workload <= 362.6000f)
353 {
354 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
355 }
356 else
357 {
358 if(r_mn <= 22.6067f)
359 {
360 if(workload <= 708.8000f)
361 {
362 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
363 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
364 std::make_pair(lhs_info_buf, rhs_info_buf),
365 n, k, b, DataType::F16);
366 }
367 else
368 {
369 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
370 }
371 }
372 else
373 {
374 if(r_nk <= 0.0917f)
375 {
376 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
377 }
378 else
379 {
380 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
381 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
382 std::make_pair(lhs_info_buf, rhs_info_buf),
383 n, k, b, DataType::F16);
384 }
385 }
386 }
387 }
388 }
389
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)390 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
391 {
392 ARM_COMPUTE_UNUSED(k);
393
394 if(m == 1)
395 {
396 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
397 }
398 else
399 {
400 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
401 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
402
403 if(workload <= 7449.60f)
404 {
405 if(workload <= 691.60f)
406 {
407 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
408 }
409 else
410 {
411 if(workload <= 4155.20f)
412 {
413 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
414 }
415 else
416 {
417 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
418 }
419 }
420 }
421 else
422 {
423 if(workload <= 16300.80f)
424 {
425 if(r_mn <= 44.56f)
426 {
427 GEMMLHSMatrixInfo lhs_info_buf;
428 GEMMRHSMatrixInfo rhs_info_buf;
429 GEMMLHSMatrixInfo lhs_info_img;
430 GEMMRHSMatrixInfo rhs_info_img;
431
432 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
433 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
434
435 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
436 std::make_pair(lhs_info_buf, rhs_info_buf),
437 n, k, b, DataType::F16);
438 }
439 else
440 {
441 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
442 }
443 }
444 else
445 {
446 GEMMLHSMatrixInfo lhs_info_buf;
447 GEMMRHSMatrixInfo rhs_info_buf;
448 GEMMLHSMatrixInfo lhs_info_img;
449 GEMMRHSMatrixInfo rhs_info_img;
450
451 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
452 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
453
454 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
455 std::make_pair(lhs_info_buf, rhs_info_buf),
456 n, k, b, DataType::F16);
457 }
458 }
459 }
460 }
461
configure_G51_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)462 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
463 {
464 ARM_COMPUTE_UNUSED(k);
465 ARM_COMPUTE_UNUSED(b);
466
467 if(m == 1)
468 {
469 const unsigned int n0 = n < 1280 ? 2 : 4;
470 const unsigned int h0 = std::max(n / n0, 1U);
471 return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
472 }
473 else
474 {
475 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
476 }
477 }
478
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)479 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
480 {
481 ARM_COMPUTE_UNUSED(k);
482 ARM_COMPUTE_UNUSED(b);
483
484 if(dot8_supported(CLKernelLibrary::get().get_device()))
485 {
486 if(m == 1)
487 {
488 const unsigned int h0 = std::max(n / 2, 1U);
489 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
490 }
491 else
492 {
493 const unsigned int h0 = std::max(n / 4, 1U);
494 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
495 }
496 }
497 else
498 {
499 const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
500 if(m == 1)
501 {
502 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
503 }
504 else
505 {
506 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
507 }
508 }
509 }
510
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)511 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
512 {
513 ARM_COMPUTE_UNUSED(k);
514 ARM_COMPUTE_UNUSED(b);
515
516 if(m == 1)
517 {
518 const unsigned int h0 = std::max(n / 2, 1U);
519 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
520 }
521 else
522 {
523 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
524 }
525 }
526
configure_G51_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)527 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
528 {
529 ARM_COMPUTE_UNUSED(k);
530 ARM_COMPUTE_UNUSED(b);
531
532 if(m == 1)
533 {
534 const unsigned int h0 = std::max(n / 2, 1U);
535 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
536 }
537 else
538 {
539 const unsigned int h0 = std::max(n / 2, 1U);
540 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
541 }
542 }
543
544 } // namespace gemm
545 } // namespace kernels
546 } // namespace opencl
547 } // namespace arm_compute
548