• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h"

#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/CL/gemm/CLGEMMHelpers.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <tuple>
#include <utility>
36 
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
41 using namespace arm_compute::misc::shape_calculator;
42 
CLGEMMReshapedOnlyRHSKernelConfigurationBifrost(GPUTarget gpu)43 CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::CLGEMMReshapedOnlyRHSKernelConfigurationBifrost(GPUTarget gpu)
44     : ICLGEMMKernelConfiguration(gpu)
45 {
46 }
47 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
51                                              unsigned int b);
52 
53     // Configurations for Mali-G51
54     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G51 =
55     {
56         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32 },
57         { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16 },
58         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 },
59         { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 },
60         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 },
61         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }
62     };
63 
64     // Configurations for Mali-G52
65     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
66     {
67         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32 },
68         { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f16 },
69         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
70         { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
71         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
72         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
73     };
74 
75     // Configurations for Mali-G76
76     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
77     {
78         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32 },
79         { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16 },
80         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 },
81         { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 },
82         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 },
83         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }
84     };
85 
86     // Configurations for Mali-G7x
87     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
88     {
89         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32 },
90         { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 },
91         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
92         { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
93         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
94         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
95     };
96 
97     switch(_target)
98     {
99         case GPUTarget::G76:
100             if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
101             {
102                 return (this->*gemm_configs_G76[data_type])(m, n, k, b);
103             }
104             else
105             {
106                 ARM_COMPUTE_ERROR("Not supported data type");
107             }
108         case GPUTarget::G52:
109             if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end())
110             {
111                 return (this->*gemm_configs_G52[data_type])(m, n, k, b);
112             }
113             else
114             {
115                 ARM_COMPUTE_ERROR("Not supported data type");
116             }
117         case GPUTarget::G51:
118             if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
119             {
120                 return (this->*gemm_configs_G51[data_type])(m, n, k, b);
121             }
122             else
123             {
124                 ARM_COMPUTE_ERROR("Not supported data type");
125             }
126         default:
127             if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
128             {
129                 return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
130             }
131             else
132             {
133                 ARM_COMPUTE_ERROR("Not supported data type");
134             }
135     }
136 }
137 
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
139 {
140     ARM_COMPUTE_UNUSED(k);
141     ARM_COMPUTE_UNUSED(b);
142 
143     if(m == 1)
144     {
145         if(n <= 2548)
146         {
147             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
148         }
149         else
150         {
151             return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
152         }
153     }
154     else
155     {
156         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
157     }
158 }
159 
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)160 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
161 {
162     ARM_COMPUTE_UNUSED(k);
163     ARM_COMPUTE_UNUSED(b);
164 
165     GEMMLHSMatrixInfo lhs_info_buf;
166     GEMMRHSMatrixInfo rhs_info_buf;
167     GEMMLHSMatrixInfo lhs_info_img;
168     GEMMRHSMatrixInfo rhs_info_img;
169 
170     const bool is_workload_big = ((m * n * b) / 16) >= 2048;
171 
172     if(m == 1)
173     {
174         if(n >= 8192)
175         {
176             const unsigned int h0 = std::max(n / 4, 1U);
177             return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
178         }
179         else
180         {
181             const unsigned int h0 = std::max(n / 2, 1U);
182             if(n <= 204)
183             {
184                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
185             }
186             else
187             {
188                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
189             }
190         }
191     }
192     else
193     {
194         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
195         if(is_workload_big)
196         {
197             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
198         }
199         else
200         {
201             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
202         }
203     }
204 
205     // Get lhs_info/rhs_info in case of OpenCL image
206     const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
207     if(is_workload_big)
208     {
209         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
210     }
211     else
212     {
213         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
214     }
215 
216     const TensorInfo  tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
217     const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
218     const TensorInfo  tensor_reshaped_info(shape, 1, DataType::F32);
219 
220     // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
221     const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
222 
223     if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
224     {
225         return std::make_pair(lhs_info_img, rhs_info_img);
226     }
227     else
228     {
229         return std::make_pair(lhs_info_buf, rhs_info_buf);
230     }
231 }
232 
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)233 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
234 {
235     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
236     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
237 
238     GEMMLHSMatrixInfo lhs_info_buf;
239     GEMMRHSMatrixInfo rhs_info_buf;
240     GEMMLHSMatrixInfo lhs_info_img;
241     GEMMRHSMatrixInfo rhs_info_img;
242 
243     if(m == 1)
244     {
245         if(r_nk <= 0.4664f)
246         {
247             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
248         }
249         else
250         {
251             std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
252             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
253 
254             return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
255                                        std::make_pair(lhs_info_buf, rhs_info_buf),
256                                        n, k, b, DataType::F32);
257         }
258     }
259     else
260     {
261         if(workload <= 274.4000f)
262         {
263             return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
264         }
265         else
266         {
267             std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
268             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
269 
270             return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
271                                        std::make_pair(lhs_info_buf, rhs_info_buf),
272                                        n, k, b, DataType::F32);
273         }
274     }
275 }
276 
configure_G51_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)277 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
278 {
279     ARM_COMPUTE_UNUSED(k);
280     ARM_COMPUTE_UNUSED(b);
281 
282     if(m == 1)
283     {
284         const unsigned int n0 = n < 1280 ? 2 : 4;
285         const unsigned int h0 = std::max(n / n0, 1U);
286         return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
287     }
288     else
289     {
290         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
291     }
292 }
293 
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)294 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
295 {
296     ARM_COMPUTE_UNUSED(k);
297     ARM_COMPUTE_UNUSED(b);
298 
299     if(m == 1)
300     {
301         if(n > 2048)
302         {
303             const unsigned int h0 = std::max(n / 4, 1U);
304             return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
305         }
306         else
307         {
308             const unsigned int h0 = std::max(n / 2, 1U);
309             return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
310         }
311     }
312     else
313     {
314         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
315     }
316 }
317 
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)318 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
319 {
320     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
321     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
322     const float r_mk = static_cast<float>(m) / static_cast<float>(k);
323     const float r_nk = static_cast<float>(n) / static_cast<float>(k);
324 
325     GEMMLHSMatrixInfo lhs_info_buf;
326     GEMMRHSMatrixInfo rhs_info_buf;
327     GEMMLHSMatrixInfo lhs_info_img;
328     GEMMRHSMatrixInfo rhs_info_img;
329 
330     if(m == 1)
331     {
332         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
333 
334         if(r_mk <= 0.0026f)
335         {
336             if(r_nk <= 0.4664f)
337             {
338                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
339             }
340             else
341             {
342                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
343                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
344                                            std::make_pair(lhs_info_buf, rhs_info_buf),
345                                            n, k, b, DataType::F16);
346             }
347         }
348         else
349         {
350             if(r_mk <= 0.0148f)
351             {
352                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
353             }
354             else
355             {
356                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
357                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
358                                            std::make_pair(lhs_info_buf, rhs_info_buf),
359                                            n, k, b, DataType::F16);
360             }
361         }
362     }
363     else
364     {
365         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
366 
367         if(workload <= 362.6000f)
368         {
369             return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
370         }
371         else
372         {
373             if(r_mn <= 22.6067f)
374             {
375                 if(workload <= 708.8000f)
376                 {
377                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
378                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
379                                                std::make_pair(lhs_info_buf, rhs_info_buf),
380                                                n, k, b, DataType::F16);
381                 }
382                 else
383                 {
384                     return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
385                 }
386             }
387             else
388             {
389                 if(r_nk <= 0.0917f)
390                 {
391                     return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
392                 }
393                 else
394                 {
395                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
396                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
397                                                std::make_pair(lhs_info_buf, rhs_info_buf),
398                                                n, k, b, DataType::F16);
399                 }
400             }
401         }
402     }
403 }
404 
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)405 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
406 {
407     ARM_COMPUTE_UNUSED(k);
408 
409     if(m == 1)
410     {
411         return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
412     }
413     else
414     {
415         const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
416         const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
417 
418         if(workload <= 7449.60f)
419         {
420             if(workload <= 691.60f)
421             {
422                 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
423             }
424             else
425             {
426                 if(workload <= 4155.20f)
427                 {
428                     return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
429                 }
430                 else
431                 {
432                     return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
433                 }
434             }
435         }
436         else
437         {
438             if(workload <= 16300.80f)
439             {
440                 if(r_mn <= 44.56f)
441                 {
442                     GEMMLHSMatrixInfo lhs_info_buf;
443                     GEMMRHSMatrixInfo rhs_info_buf;
444                     GEMMLHSMatrixInfo lhs_info_img;
445                     GEMMRHSMatrixInfo rhs_info_img;
446 
447                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
448                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
449 
450                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
451                                                std::make_pair(lhs_info_buf, rhs_info_buf),
452                                                n, k, b, DataType::F16);
453                 }
454                 else
455                 {
456                     return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
457                 }
458             }
459             else
460             {
461                 GEMMLHSMatrixInfo lhs_info_buf;
462                 GEMMRHSMatrixInfo rhs_info_buf;
463                 GEMMLHSMatrixInfo lhs_info_img;
464                 GEMMRHSMatrixInfo rhs_info_img;
465 
466                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
467                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
468 
469                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
470                                            std::make_pair(lhs_info_buf, rhs_info_buf),
471                                            n, k, b, DataType::F16);
472             }
473         }
474     }
475 }
476 
configure_G51_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)477 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
478 {
479     ARM_COMPUTE_UNUSED(k);
480     ARM_COMPUTE_UNUSED(b);
481 
482     if(m == 1)
483     {
484         const unsigned int n0 = n < 1280 ? 2 : 4;
485         const unsigned int h0 = std::max(n / n0, 1U);
486         return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
487     }
488     else
489     {
490         return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
491     }
492 }
493 
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)494 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
495 {
496     ARM_COMPUTE_UNUSED(k);
497     ARM_COMPUTE_UNUSED(b);
498 
499     if(dot8_supported(CLKernelLibrary::get().get_device()))
500     {
501         if(m == 1)
502         {
503             const unsigned int h0 = std::max(n / 2, 1U);
504             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
505         }
506         else
507         {
508             const unsigned int h0 = std::max(n / 4, 1U);
509             return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
510         }
511     }
512     else
513     {
514         const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
515         if(m == 1)
516         {
517             return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
518         }
519         else
520         {
521             return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
522         }
523     }
524 }
525 
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)526 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
527 {
528     ARM_COMPUTE_UNUSED(k);
529     ARM_COMPUTE_UNUSED(b);
530 
531     if(m == 1)
532     {
533         const unsigned int h0 = std::max(n / 2, 1U);
534         return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
535     }
536     else
537     {
538         return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
539     }
540 }
541 
configure_G51_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)542 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
543 {
544     ARM_COMPUTE_UNUSED(k);
545     ARM_COMPUTE_UNUSED(b);
546 
547     if(m == 1)
548     {
549         const unsigned int h0 = std::max(n / 2, 1U);
550         return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
551     }
552     else
553     {
554         const unsigned int h0 = std::max(n / 2, 1U);
555         return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
556     }
557 }
558 
559 } // namespace cl_gemm
560 } // namespace arm_compute
561