1 /*
2 * Copyright (c) 2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/core/CL/gemm/CLGEMMHelpers.h"
33
34 #include <map>
35 #include <utility>
36
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
41 using namespace arm_compute::misc::shape_calculator;
42
/** Constructor
 *
 * Holds no state of its own: the GPU target is simply forwarded to the
 * ICLGEMMKernelConfiguration base class.
 *
 * @param[in] gpu GPU target handed to the base class
 */
CLGEMMReshapedOnlyRHSKernelConfigurationValhall::CLGEMMReshapedOnlyRHSKernelConfigurationValhall(GPUTarget gpu)
    : ICLGEMMKernelConfiguration(gpu)
{
}
47
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedOnlyRHSKernelConfigurationValhall::*)(unsigned int m, unsigned int n, unsigned int k,
51 unsigned int b);
52
53 // Configurations for Mali-G77
54 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G77 =
55 {
56 { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32 },
57 { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f16 },
58 { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 },
59 { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 },
60 { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 },
61 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }
62 };
63
64 switch(_target)
65 {
66 case GPUTarget::G77:
67 default:
68 if(gemm_configs_G77.find(data_type) != gemm_configs_G77.end())
69 {
70 return (this->*gemm_configs_G77[data_type])(m, n, k, b);
71 }
72 else
73 {
74 ARM_COMPUTE_ERROR("Not supported data type");
75 }
76 }
77 }
78
configure_G77_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)79 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
80 {
81 if(m == 1)
82 {
83 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
84 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
85
86 if(r_mk <= 0.0064484127797186375)
87 {
88 if(r_mn <= 0.0028273810748942196)
89 {
90 GEMMLHSMatrixInfo lhs_info_buf;
91 GEMMRHSMatrixInfo rhs_info_buf;
92 GEMMLHSMatrixInfo lhs_info_img;
93 GEMMRHSMatrixInfo rhs_info_img;
94
95 const unsigned int h0 = std::max(n / 4, 1U);
96 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, false, true);
97 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true, false);
98
99 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
100 std::make_pair(lhs_info_buf, rhs_info_buf),
101 n, k, b, DataType::F32);
102 }
103 else
104 {
105 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, false, true, false, false, false);
106 }
107 }
108 else
109 {
110 if(r_mk <= 0.020312500186264515)
111 {
112 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, false, false);
113 }
114 else
115 {
116 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, true, false);
117 }
118 }
119 }
120 else
121 {
122 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
123 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
124 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
125
126 if(workload <= 1999.2000122070312)
127 {
128 if(workload <= 747.1999816894531)
129 {
130 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
131 }
132 else
133 {
134 GEMMLHSMatrixInfo lhs_info_buf;
135 GEMMRHSMatrixInfo rhs_info_buf;
136 GEMMLHSMatrixInfo lhs_info_img;
137 GEMMRHSMatrixInfo rhs_info_img;
138 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true);
139 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
140
141 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
142 std::make_pair(lhs_info_buf, rhs_info_buf),
143 n, k, b, DataType::F32);
144 }
145 }
146 else
147 {
148 if(r_mn <= 0.03348214365541935)
149 {
150 if(r_mk <= 0.028125000186264515)
151 {
152 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
153 }
154 else
155 {
156 GEMMLHSMatrixInfo lhs_info_buf;
157 GEMMRHSMatrixInfo rhs_info_buf;
158 GEMMLHSMatrixInfo lhs_info_img;
159 GEMMRHSMatrixInfo rhs_info_img;
160 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true);
161 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
162
163 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
164 std::make_pair(lhs_info_buf, rhs_info_buf),
165 n, k, b, DataType::F32);
166 }
167 }
168 else
169 {
170 GEMMLHSMatrixInfo lhs_info_buf;
171 GEMMRHSMatrixInfo rhs_info_buf;
172 GEMMLHSMatrixInfo lhs_info_img;
173 GEMMRHSMatrixInfo rhs_info_img;
174 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, false, true);
175 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, false, true, false, true, false);
176
177 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
178 std::make_pair(lhs_info_buf, rhs_info_buf),
179 n, k, b, DataType::F32);
180 }
181 }
182 }
183 }
184
configure_G77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)185 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
186 {
187 ARM_COMPUTE_UNUSED(k);
188 ARM_COMPUTE_UNUSED(b);
189
190 if(m == 1)
191 {
192 const unsigned int h0 = std::max(n / 2, 1U);
193 if(n <= 836.0)
194 {
195 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
196 }
197 else
198 {
199 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
200 }
201 }
202 else if(m < 128)
203 {
204 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
205 if(k >= 512)
206 {
207 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, false);
208 }
209 else
210 {
211 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, false);
212 }
213 }
214 else
215 {
216 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
217 if(n >= 64)
218 {
219 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false);
220 }
221 else
222 {
223 if(k >= 512)
224 {
225 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, false);
226 }
227 else
228 {
229 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, false);
230 }
231 }
232 }
233 }
234
configure_G77_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)235 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
236 {
237 ARM_COMPUTE_UNUSED(k);
238 ARM_COMPUTE_UNUSED(b);
239
240 if(m == 1)
241 {
242 const unsigned int h0 = std::max(n / 2, 1U);
243 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
244 }
245 else
246 {
247 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
248 if(m >= 28)
249 {
250 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
251 }
252 else
253 {
254 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, true);
255 }
256 }
257 }
258 } // namespace cl_gemm
259 } // namespace arm_compute
260