• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/core/CL/gemm/CLGEMMHelpers.h"
33 
34 #include <map>
35 #include <utility>
36 
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
41 using namespace arm_compute::misc::shape_calculator;
42 
CLGEMMReshapedOnlyRHSKernelConfigurationValhall(GPUTarget gpu)43 CLGEMMReshapedOnlyRHSKernelConfigurationValhall::CLGEMMReshapedOnlyRHSKernelConfigurationValhall(GPUTarget gpu)
44     : ICLGEMMKernelConfiguration(gpu)
45 {
46 }
47 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedOnlyRHSKernelConfigurationValhall::*)(unsigned int m, unsigned int n, unsigned int k,
51                                              unsigned int b);
52 
53     // Configurations for Mali-G77
54     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G77 =
55     {
56         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32 },
57         { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f16 },
58         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 },
59         { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 },
60         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 },
61         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }
62     };
63 
64     switch(_target)
65     {
66         case GPUTarget::G77:
67         default:
68             if(gemm_configs_G77.find(data_type) != gemm_configs_G77.end())
69             {
70                 return (this->*gemm_configs_G77[data_type])(m, n, k, b);
71             }
72             else
73             {
74                 ARM_COMPUTE_ERROR("Not supported data type");
75             }
76     }
77 }
78 
configure_G77_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)79 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
80 {
81     if(m == 1)
82     {
83         const float r_mn = static_cast<float>(m) / static_cast<float>(n);
84         const float r_mk = static_cast<float>(m) / static_cast<float>(k);
85 
86         if(r_mk <= 0.0064484127797186375)
87         {
88             if(r_mn <= 0.0028273810748942196)
89             {
90                 GEMMLHSMatrixInfo lhs_info_buf;
91                 GEMMRHSMatrixInfo rhs_info_buf;
92                 GEMMLHSMatrixInfo lhs_info_img;
93                 GEMMRHSMatrixInfo rhs_info_img;
94 
95                 const unsigned int h0 = std::max(n / 4, 1U);
96                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, false, true);
97                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true, false);
98 
99                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
100                                            std::make_pair(lhs_info_buf, rhs_info_buf),
101                                            n, k, b, DataType::F32);
102             }
103             else
104             {
105                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, false, true, false, false, false);
106             }
107         }
108         else
109         {
110             if(r_mk <= 0.020312500186264515)
111             {
112                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, false, false);
113             }
114             else
115             {
116                 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, true, false);
117             }
118         }
119     }
120     else
121     {
122         const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
123         const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
124         const float r_mk = static_cast<float>(m) / static_cast<float>(k);
125 
126         if(workload <= 1999.2000122070312)
127         {
128             if(workload <= 747.1999816894531)
129             {
130                 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
131             }
132             else
133             {
134                 GEMMLHSMatrixInfo lhs_info_buf;
135                 GEMMRHSMatrixInfo rhs_info_buf;
136                 GEMMLHSMatrixInfo lhs_info_img;
137                 GEMMRHSMatrixInfo rhs_info_img;
138                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true);
139                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
140 
141                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
142                                            std::make_pair(lhs_info_buf, rhs_info_buf),
143                                            n, k, b, DataType::F32);
144             }
145         }
146         else
147         {
148             if(r_mn <= 0.03348214365541935)
149             {
150                 if(r_mk <= 0.028125000186264515)
151                 {
152                     return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
153                 }
154                 else
155                 {
156                     GEMMLHSMatrixInfo lhs_info_buf;
157                     GEMMRHSMatrixInfo rhs_info_buf;
158                     GEMMLHSMatrixInfo lhs_info_img;
159                     GEMMRHSMatrixInfo rhs_info_img;
160                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true);
161                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
162 
163                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
164                                                std::make_pair(lhs_info_buf, rhs_info_buf),
165                                                n, k, b, DataType::F32);
166                 }
167             }
168             else
169             {
170                 GEMMLHSMatrixInfo lhs_info_buf;
171                 GEMMRHSMatrixInfo rhs_info_buf;
172                 GEMMLHSMatrixInfo lhs_info_img;
173                 GEMMRHSMatrixInfo rhs_info_img;
174                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, false, true);
175                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, false, true, false, true, false);
176 
177                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
178                                             std::make_pair(lhs_info_buf, rhs_info_buf),
179                                             n, k, b, DataType::F32);
180             }
181         }
182     }
183 }
184 
configure_G77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)185 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
186 {
187     ARM_COMPUTE_UNUSED(k);
188     ARM_COMPUTE_UNUSED(b);
189 
190     if(m == 1)
191     {
192         const unsigned int h0 = std::max(n / 2, 1U);
193         if(n <= 836.0)
194         {
195             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
196         }
197         else
198         {
199             return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
200         }
201     }
202     else if(m < 128)
203     {
204         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
205         if(k >= 512)
206         {
207             return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, false);
208         }
209         else
210         {
211             return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, false);
212         }
213     }
214     else
215     {
216         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
217         if(n >= 64)
218         {
219             return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false);
220         }
221         else
222         {
223             if(k >= 512)
224             {
225                 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, false);
226             }
227             else
228             {
229                 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, false);
230             }
231         }
232     }
233 }
234 
configure_G77_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)235 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
236 {
237     ARM_COMPUTE_UNUSED(k);
238     ARM_COMPUTE_UNUSED(b);
239 
240     if(m == 1)
241     {
242         const unsigned int h0 = std::max(n / 2, 1U);
243         return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
244     }
245     else
246     {
247         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
248         if(m >= 28)
249         {
250             return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
251         }
252         else
253         {
254             return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, true);
255         }
256     }
257 }
258 } // namespace cl_gemm
259 } // namespace arm_compute
260