• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "src/core/CL/gemm/CLGEMMHelpers.h"
30 
31 #include <map>
32 #include <utility>
33 
34 namespace arm_compute
35 {
36 namespace cl_gemm
37 {
CLGEMMNativeKernelConfigurationBifrost(GPUTarget gpu)38 CLGEMMNativeKernelConfigurationBifrost::CLGEMMNativeKernelConfigurationBifrost(GPUTarget gpu)
39     : ICLGEMMKernelConfiguration(gpu)
40 {
41 }
42 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)43 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
44 {
45     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMNativeKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
46                                              unsigned int b);
47 
48     // Configurations for Mali-G71
49     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G71 =
50     {
51         { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32 },
52         { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
53         { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
54         { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
55         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }
56     };
57 
58     // Configurations for Mali-G76
59     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
60     {
61         { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32 },
62         { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
63         { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
64         { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
65         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }
66     };
67 
68     // Default configurations
69     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_default =
70     {
71         { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_default_f32 },
72         { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
73         { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
74         { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
75         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }
76     };
77 
78     switch(_target)
79     {
80         case GPUTarget::G71:
81             if(gemm_configs_G71.find(data_type) != gemm_configs_G71.end())
82             {
83                 return (this->*gemm_configs_G71[data_type])(m, n, k, b);
84             }
85             else
86             {
87                 ARM_COMPUTE_ERROR("Not supported data type");
88             }
89         case GPUTarget::G76:
90             if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
91             {
92                 return (this->*gemm_configs_G76[data_type])(m, n, k, b);
93             }
94             else
95             {
96                 ARM_COMPUTE_ERROR("Not supported data type");
97             }
98         default:
99             if(gemm_configs_default.find(data_type) != gemm_configs_default.end())
100             {
101                 return (this->*gemm_configs_default[data_type])(m, n, k, b);
102             }
103             else
104             {
105                 ARM_COMPUTE_ERROR("Not supported data type");
106             }
107     }
108 }
109 
configure_G71_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)110 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
111 {
112     ARM_COMPUTE_UNUSED(k);
113     ARM_COMPUTE_UNUSED(b);
114 
115     if(m == 1)
116     {
117         if(n < 2048)
118         {
119             return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
120         }
121         else if(n >= 2048 && n < 8192)
122         {
123             return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false);
124         }
125         else
126         {
127             return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false);
128         }
129     }
130     else
131     {
132         return configure_lhs_rhs_info(m, n, 5, 4, 2, 1, 1, false, false, false, false);
133     }
134 }
135 
configure_G71_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)136 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
137 {
138     ARM_COMPUTE_UNUSED(k);
139     ARM_COMPUTE_UNUSED(b);
140 
141     if(dot8_supported(CLKernelLibrary::get().get_device()))
142     {
143         if(m == 1)
144         {
145             if(n < 2048)
146             {
147                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
148             }
149             else if(n >= 2048 && n < 16384)
150             {
151                 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
152             }
153             else
154             {
155                 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
156             }
157         }
158         else
159         {
160             if(m < 64)
161             {
162                 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
163             }
164             else
165             {
166                 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
167             }
168         }
169     }
170     else
171     {
172         if(m == 1)
173         {
174             if(n < 8192)
175             {
176                 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
177             }
178             else
179             {
180                 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
181             }
182         }
183         else
184         {
185             return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1, false, false, false, false);
186         }
187     }
188 }
189 
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)190 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
191 {
192     ARM_COMPUTE_UNUSED(k);
193     ARM_COMPUTE_UNUSED(b);
194 
195     if(m == 1)
196     {
197         if(n > 4196)
198         {
199             return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 1, false, false, false, false);
200         }
201         else
202         {
203             if(k < 2048)
204             {
205                 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 1, false, false, false, false);
206             }
207             else if(k >= 2048 && k < 16384)
208             {
209                 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
210             }
211             else
212             {
213                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 1, false, false, false, false);
214             }
215         }
216     }
217     else
218     {
219         return configure_lhs_rhs_info(m, n, 2, 8, 2, 1, 1, false, false, false, false);
220     }
221 }
222 
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)223 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
224 {
225     ARM_COMPUTE_UNUSED(k);
226     ARM_COMPUTE_UNUSED(b);
227 
228     if(m == 1)
229     {
230         if(n < 2048)
231         {
232             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
233         }
234         else if(n >= 2048 && n < 16384)
235         {
236             return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
237         }
238         else
239         {
240             return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
241         }
242     }
243     else
244     {
245         if(m < 64)
246         {
247             return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
248         }
249         else
250         {
251             return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
252         }
253     }
254 }
255 
configure_default_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)256 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
257 {
258     ARM_COMPUTE_UNUSED(k);
259     ARM_COMPUTE_UNUSED(b);
260 
261     return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 1, false, false, false, false);
262 }
263 
configure_default_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)264 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
265 {
266     ARM_COMPUTE_UNUSED(k);
267     ARM_COMPUTE_UNUSED(b);
268 
269     return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
270 }
271 } // namespace cl_gemm
272 } // namespace arm_compute