• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
33 #include <utility>
34 
35 namespace arm_compute
36 {
37 namespace opencl
38 {
39 namespace kernels
40 {
41 namespace gemm
42 {
43 using namespace arm_compute::misc::shape_calculator;
44 
ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu)45 ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu)
46     : IClGemmKernelConfig(gpu)
47 {
48 }
49 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)50 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
51 {
52     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
53                                              unsigned int b);
54 
55     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32,
56                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
57                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
58 
59     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32,
60                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
61                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
62 
63     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
64                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
65                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
66 
67     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
68                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
69                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
70 
71     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
72                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
73                                                                     &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
74 
75     ConfigurationFunctionExecutorPtr func = nullptr;
76     switch(_target)
77     {
78         case GPUTarget::G76:
79             func = configs_G76.get_function(data_type);
80             break;
81         case GPUTarget::G51:
82             func = configs_G51.get_function(data_type);
83             break;
84         case GPUTarget::G52:
85             func = configs_G52.get_function(data_type);
86             break;
87         case GPUTarget::G31:
88             func = configs_G31.get_function(data_type);
89             break;
90         default:
91             func = configs_G7x.get_function(data_type);
92             break;
93     }
94 
95     ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
96     return (this->*func)(m, n, k, b);
97 }
98 
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)99 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
100 {
101     ARM_COMPUTE_UNUSED(k);
102     ARM_COMPUTE_UNUSED(b);
103 
104     if(m == 1)
105     {
106         if(n <= 2548)
107         {
108             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
109         }
110         else
111         {
112             return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
113         }
114     }
115     else
116     {
117         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
118     }
119 }
120 
configure_G31_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)121 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
122 {
123     ARM_COMPUTE_UNUSED(k);
124     ARM_COMPUTE_UNUSED(b);
125 
126     if(m == 1)
127     {
128         const unsigned int h0 = std::max(n / 2, 1U);
129         return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
130     }
131     else
132     {
133         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
134         if(m >= 28)
135         {
136             return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
137         }
138         else
139         {
140             return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
141         }
142     }
143 }
144 
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
146 {
147     ARM_COMPUTE_UNUSED(k);
148     ARM_COMPUTE_UNUSED(b);
149 
150     GEMMLHSMatrixInfo lhs_info_buf;
151     GEMMRHSMatrixInfo rhs_info_buf;
152     GEMMLHSMatrixInfo lhs_info_img;
153     GEMMRHSMatrixInfo rhs_info_img;
154 
155     const bool is_workload_big = ((m * n * b) / 16) >= 2048;
156 
157     if(m == 1)
158     {
159         if(n >= 8192)
160         {
161             const unsigned int h0 = std::max(n / 4, 1U);
162             return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
163         }
164         else
165         {
166             const unsigned int h0 = std::max(n / 2, 1U);
167             if(n <= 204)
168             {
169                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
170             }
171             else
172             {
173                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
174             }
175         }
176     }
177     else
178     {
179         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
180         if(is_workload_big)
181         {
182             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
183         }
184         else
185         {
186             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
187         }
188     }
189 
190     // Get lhs_info/rhs_info in case of OpenCL image
191     const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
192     if(is_workload_big)
193     {
194         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
195     }
196     else
197     {
198         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
199     }
200 
201     const TensorInfo  tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
202     const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
203     const TensorInfo  tensor_reshaped_info(shape, 1, DataType::F32);
204 
205     // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
206     const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
207 
208     if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
209     {
210         return std::make_pair(lhs_info_img, rhs_info_img);
211     }
212     else
213     {
214         return std::make_pair(lhs_info_buf, rhs_info_buf);
215     }
216 }
217 
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)218 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
219 {
220     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
221     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
222 
223     GEMMLHSMatrixInfo lhs_info_buf;
224     GEMMRHSMatrixInfo rhs_info_buf;
225     GEMMLHSMatrixInfo lhs_info_img;
226     GEMMRHSMatrixInfo rhs_info_img;
227 
228     if(m == 1)
229     {
230         if(r_nk <= 0.4664f)
231         {
232             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
233         }
234         else
235         {
236             std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
237             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
238 
239             return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
240                                        std::make_pair(lhs_info_buf, rhs_info_buf),
241                                        n, k, b, DataType::F32);
242         }
243     }
244     else
245     {
246         if(workload <= 274.4000f)
247         {
248             return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
249         }
250         else
251         {
252             std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
253             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
254 
255             return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
256                                        std::make_pair(lhs_info_buf, rhs_info_buf),
257                                        n, k, b, DataType::F32);
258         }
259     }
260 }
261 
configure_G51_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)262 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
263 {
264     ARM_COMPUTE_UNUSED(k);
265     ARM_COMPUTE_UNUSED(b);
266 
267     if(m == 1)
268     {
269         const unsigned int n0 = n < 1280 ? 2 : 4;
270         const unsigned int h0 = std::max(n / n0, 1U);
271         return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
272     }
273     else
274     {
275         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
276     }
277 }
278 
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)279 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
280 {
281     ARM_COMPUTE_UNUSED(k);
282     ARM_COMPUTE_UNUSED(b);
283 
284     if(m == 1)
285     {
286         if(n > 2048)
287         {
288             const unsigned int h0 = std::max(n / 4, 1U);
289             return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
290         }
291         else
292         {
293             const unsigned int h0 = std::max(n / 2, 1U);
294             return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
295         }
296     }
297     else
298     {
299         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
300     }
301 }
302 
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)303 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
304 {
305     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
306     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
307     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
308     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
309 
310     GEMMLHSMatrixInfo lhs_info_buf;
311     GEMMRHSMatrixInfo rhs_info_buf;
312     GEMMLHSMatrixInfo lhs_info_img;
313     GEMMRHSMatrixInfo rhs_info_img;
314 
315     if(m == 1)
316     {
317         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
318 
319         if(r_mk <= 0.0026f)
320         {
321             if(r_nk <= 0.4664f)
322             {
323                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
324             }
325             else
326             {
327                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
328                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
329                                            std::make_pair(lhs_info_buf, rhs_info_buf),
330                                            n, k, b, DataType::F16);
331             }
332         }
333         else
334         {
335             if(r_mk <= 0.0148f)
336             {
337                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
338             }
339             else
340             {
341                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
342                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
343                                            std::make_pair(lhs_info_buf, rhs_info_buf),
344                                            n, k, b, DataType::F16);
345             }
346         }
347     }
348     else
349     {
350         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
351 
352         if(workload <= 362.6000f)
353         {
354             return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
355         }
356         else
357         {
358             if(r_mn <= 22.6067f)
359             {
360                 if(workload <= 708.8000f)
361                 {
362                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
363                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
364                                                std::make_pair(lhs_info_buf, rhs_info_buf),
365                                                n, k, b, DataType::F16);
366                 }
367                 else
368                 {
369                     return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
370                 }
371             }
372             else
373             {
374                 if(r_nk <= 0.0917f)
375                 {
376                     return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
377                 }
378                 else
379                 {
380                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
381                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
382                                                std::make_pair(lhs_info_buf, rhs_info_buf),
383                                                n, k, b, DataType::F16);
384                 }
385             }
386         }
387     }
388 }
389 
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)390 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
391 {
392     ARM_COMPUTE_UNUSED(k);
393 
394     if(m == 1)
395     {
396         return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
397     }
398     else
399     {
400         const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
401         const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
402 
403         if(workload <= 7449.60f)
404         {
405             if(workload <= 691.60f)
406             {
407                 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
408             }
409             else
410             {
411                 if(workload <= 4155.20f)
412                 {
413                     return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
414                 }
415                 else
416                 {
417                     return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
418                 }
419             }
420         }
421         else
422         {
423             if(workload <= 16300.80f)
424             {
425                 if(r_mn <= 44.56f)
426                 {
427                     GEMMLHSMatrixInfo lhs_info_buf;
428                     GEMMRHSMatrixInfo rhs_info_buf;
429                     GEMMLHSMatrixInfo lhs_info_img;
430                     GEMMRHSMatrixInfo rhs_info_img;
431 
432                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
433                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
434 
435                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
436                                                std::make_pair(lhs_info_buf, rhs_info_buf),
437                                                n, k, b, DataType::F16);
438                 }
439                 else
440                 {
441                     return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
442                 }
443             }
444             else
445             {
446                 GEMMLHSMatrixInfo lhs_info_buf;
447                 GEMMRHSMatrixInfo rhs_info_buf;
448                 GEMMLHSMatrixInfo lhs_info_img;
449                 GEMMRHSMatrixInfo rhs_info_img;
450 
451                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
452                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
453 
454                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
455                                            std::make_pair(lhs_info_buf, rhs_info_buf),
456                                            n, k, b, DataType::F16);
457             }
458         }
459     }
460 }
461 
configure_G51_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)462 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
463 {
464     ARM_COMPUTE_UNUSED(k);
465     ARM_COMPUTE_UNUSED(b);
466 
467     if(m == 1)
468     {
469         const unsigned int n0 = n < 1280 ? 2 : 4;
470         const unsigned int h0 = std::max(n / n0, 1U);
471         return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
472     }
473     else
474     {
475         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
476     }
477 }
478 
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)479 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
480 {
481     ARM_COMPUTE_UNUSED(k);
482     ARM_COMPUTE_UNUSED(b);
483 
484     if(dot8_supported(CLKernelLibrary::get().get_device()))
485     {
486         if(m == 1)
487         {
488             const unsigned int h0 = std::max(n / 2, 1U);
489             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
490         }
491         else
492         {
493             const unsigned int h0 = std::max(n / 4, 1U);
494             return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
495         }
496     }
497     else
498     {
499         const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
500         if(m == 1)
501         {
502             return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
503         }
504         else
505         {
506             return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
507         }
508     }
509 }
510 
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)511 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
512 {
513     ARM_COMPUTE_UNUSED(k);
514     ARM_COMPUTE_UNUSED(b);
515 
516     if(m == 1)
517     {
518         const unsigned int h0 = std::max(n / 2, 1U);
519         return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
520     }
521     else
522     {
523         return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
524     }
525 }
526 
configure_G51_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)527 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
528 {
529     ARM_COMPUTE_UNUSED(k);
530     ARM_COMPUTE_UNUSED(b);
531 
532     if(m == 1)
533     {
534         const unsigned int h0 = std::max(n / 2, 1U);
535         return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
536     }
537     else
538     {
539         const unsigned int h0 = std::max(n / 2, 1U);
540         return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
541     }
542 }
543 
544 } // namespace gemm
545 } // namespace kernels
546 } // namespace opencl
547 } // namespace arm_compute
548