• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/core/CL/gemm/CLGEMMHelpers.h"
33 
#include <map>
#include <tuple>
#include <utility>
36 
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
41 using namespace arm_compute::misc::shape_calculator;
42 
// Constructor: records the target GPU in the base class; the target is later
// read back (as _target) by configure() to pick the per-GPU heuristic table.
CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifrost(GPUTarget gpu)
    : ICLGEMMKernelConfiguration(gpu)
{
}
47 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
51 
52     // Configurations for Mali-G76
53     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
54     {
55         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
56         { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
57         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 },
58         { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 },
59         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 },
60         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
61     };
62 
63     // Configurations for Mali-G52
64     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
65     {
66         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32 },
67         { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f16 },
68         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
69         { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
70         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
71         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
72     };
73 
74     // Configurations for Mali-G7x
75     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
76     {
77         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
78         { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
79         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
80         { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
81         { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
82         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
83     };
84 
85     switch(_target)
86     {
87         case GPUTarget::G76:
88             if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
89             {
90                 return (this->*gemm_configs_G76[data_type])(m, n, k, b);
91             }
92             else
93             {
94                 ARM_COMPUTE_ERROR("Not supported data type");
95             }
96         default:
97             if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
98             {
99                 return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
100             }
101             else
102             {
103                 ARM_COMPUTE_ERROR("Not supported data type");
104             }
105     }
106 }
107 
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)108 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
109 {
110     ARM_COMPUTE_UNUSED(k);
111     ARM_COMPUTE_UNUSED(b);
112 
113     if(n <= 4)
114     {
115         return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
116     }
117     else
118     {
119         return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true);
120     }
121 }
122 
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)123 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
124 {
125     ARM_COMPUTE_UNUSED(k);
126     ARM_COMPUTE_UNUSED(b);
127 
128     if(n <= 4)
129     {
130         return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
131     }
132     else
133     {
134         return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
135     }
136 }
137 
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
139 {
140     ARM_COMPUTE_UNUSED(k);
141     ARM_COMPUTE_UNUSED(b);
142 
143     if(dot8_supported(CLKernelLibrary::get().get_device()))
144     {
145         if(n <= 4)
146         {
147             return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true);
148         }
149         else
150         {
151             return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true);
152         }
153     }
154     else
155     {
156         if(n <= 4)
157         {
158             return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true);
159         }
160         else
161         {
162             return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true);
163         }
164     }
165 }
166 
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)167 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
168 {
169     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
170     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
171     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
172     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
173 
174     GEMMLHSMatrixInfo lhs_info_buf;
175     GEMMRHSMatrixInfo rhs_info_buf;
176     GEMMLHSMatrixInfo lhs_info_img;
177     GEMMRHSMatrixInfo rhs_info_img;
178 
179     if(workload <= 274.4000f)
180     {
181         if(r_nk <= 0.7461f)
182         {
183             if(r_mn <= 21.1667f)
184             {
185                 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false);
186             }
187             else
188             {
189                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
190                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
191 
192                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
193                                            std::make_pair(lhs_info_buf, rhs_info_buf),
194                                            n, k, b, DataType::F32);
195             }
196         }
197         else
198         {
199             std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
200             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
201 
202             return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
203                                        std::make_pair(lhs_info_buf, rhs_info_buf),
204                                        n, k, b, DataType::F32);
205         }
206     }
207     else
208     {
209         if(r_mk <= 17.3926f)
210         {
211             if(workload <= 542.4000f)
212             {
213                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
214                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
215 
216                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
217                                            std::make_pair(lhs_info_buf, rhs_info_buf),
218                                            n, k, b, DataType::F32);
219             }
220             else
221             {
222                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
223                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
224 
225                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
226                                            std::make_pair(lhs_info_buf, rhs_info_buf),
227                                            n, k, b, DataType::F32);
228             }
229         }
230         else
231         {
232             if(r_nk <= 0.5463f)
233             {
234                 if(workload <= 11767.6001f)
235                 {
236                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
237                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
238 
239                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
240                                                std::make_pair(lhs_info_buf, rhs_info_buf),
241                                                n, k, b, DataType::F32);
242                 }
243                 else
244                 {
245                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
246                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
247 
248                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
249                                                std::make_pair(lhs_info_buf, rhs_info_buf),
250                                                n, k, b, DataType::F32);
251                 }
252             }
253             else
254             {
255                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
256                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
257 
258                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
259                                            std::make_pair(lhs_info_buf, rhs_info_buf),
260                                            n, k, b, DataType::F32);
261             }
262         }
263     }
264 }
265 
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)266 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
267 {
268     ARM_COMPUTE_UNUSED(k);
269 
270     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
271 
272     if(workload <= 323.4000f)
273     {
274         return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
275     }
276     else
277     {
278         return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
279     }
280 }
281 
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)282 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
283 {
284     ARM_COMPUTE_UNUSED(k);
285     ARM_COMPUTE_UNUSED(b);
286 
287     GEMMLHSMatrixInfo lhs_info_buf;
288     GEMMRHSMatrixInfo rhs_info_buf;
289     GEMMLHSMatrixInfo lhs_info_img;
290     GEMMRHSMatrixInfo rhs_info_img;
291 
292     // Get lhs_info/rhs_info in case of OpenCL buffer
293     if(n <= 4)
294     {
295         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
296     }
297     else
298     {
299         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
300     }
301 
302     // Get lhs_info/rhs_info in case of OpenCL image
303     // Condition on the GPU workload
304     if((m / 4) * (n / 4) >= 2560)
305     {
306         // Big workload
307         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
308     }
309     else
310     {
311         // Small workload
312         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
313     }
314 
315     const TensorInfo  tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
316     const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
317     const TensorInfo  tensor_reshaped_info(shape, 1, DataType::F32);
318 
319     // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
320     const bool use_cl_image2d = (n <= 4) ? false : true;
321 
322     if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
323     {
324         return std::make_pair(lhs_info_img, rhs_info_img);
325     }
326     else
327     {
328         return std::make_pair(lhs_info_buf, rhs_info_buf);
329     }
330 }
331 
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)332 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
333 {
334     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
335     const float r_mk = static_cast<float>(m) / static_cast<float>(k);
336 
337     if(workload <= 1595.2000f)
338     {
339         if(r_mk <= 2.1044f)
340         {
341             if(workload <= 870.4000f)
342             {
343                 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false);
344             }
345             else
346             {
347                 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
348             }
349         }
350         else
351         {
352             return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
353         }
354     }
355     else
356     {
357         return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false, false);
358     }
359 }
360 
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)361 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
362 {
363     ARM_COMPUTE_UNUSED(k);
364     ARM_COMPUTE_UNUSED(b);
365 
366     if(n <= 4)
367     {
368         return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
369     }
370     else
371     {
372         return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true);
373     }
374 }
375 } // namespace cl_gemm
376 } // namespace arm_compute
377