/*
 * Copyright (c) 2017-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/gpu/cl/operators/ClGemm.h"

#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Log.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/ITensorAllocator.h"

#include "arm_compute/core/experimental/IPostOp.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/helpers/float_ops.h"
#include "src/gpu/cl/IClKernel.h"
#include "src/gpu/cl/utils/ClAuxTensorHandler.h"
#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"

#include "src/common/utils/Log.h"
#include "support/Cast.h"
#include "utils/TypePrinter.h"

namespace arm_compute
{
namespace opencl
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::cl_gemm;
using namespace arm_compute::experimental;
using namespace arm_compute::utils::cast;
using namespace arm_compute::opencl::kernels;

namespace
{
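// Vet a kernel type suggested by the MLGO heuristics: a NATIVE suggestion is
// rejected so that selection falls back to the default heuristics instead.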
inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
{
    return kernel_type != CLGEMMKernelType::NATIVE;
}
// Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
{
    if(!constant_weights)
    {
        return CLGEMMKernelType::NATIVE;
    }

    auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
    if(bool(gemm_kernel))
    {
        if(validate_gemm_kernel(gemm_kernel.gemm_type))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
            return gemm_kernel.gemm_type;
        }
    }
    gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
    return gemm_kernel.gemm_type;
}
// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                                    const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
    TensorInfo tmp_b_info{};
    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info  = lhs_info;
    gemm_kernel_info.rhs_info  = rhs_info;
    gemm_kernel_info.has_pad_y = false;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    gemm_kernel_info.has_pad_y = true;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
                                                                                                 const ITensorInfo *b,
                                                                                                 const ITensorInfo *c, const ITensorInfo *output)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}

// Validate lhs_info and rhs_info for reshaped kernel
inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                           const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Validate reshape LHS kernel
    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
    if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
    {
        return false;
    }

    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
                                                                                        const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}
} // namespace

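// Usage sketch (illustrative only; the user-side tensor and info objects below are
// assumptions, not part of this file):
//
//   ClGemm gemm;
//   gemm.configure(CLKernelLibrary::get().get_compile_context(), &a_info, &b_info, nullptr, &dst_info, 1.f, 0.f, GEMMInfo{});
//   ITensorPack pack{ { ACL_SRC_0, &a }, { ACL_SRC_1, &b }, { ACL_DST, &dst } };
//   // Auxiliary buffers reported by gemm.workspace() must also be bound to the
//   // pack under their slot ids before running
//   gemm.run(pack);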
ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
      _tmp_a(),
      _tmp_b(),
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE),
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count)
{
}

void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                              const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

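    // Note: the lhs/rhs block sizes for the native kernel are taken from the
    // reshaped-only-rhs MLGO configuration query below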
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
}

void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
                                                                    c, output, gemm_info.reinterpret_input_as_3d());

    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for LHS and RHS reshape matrix
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                         const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);

    // Transpose matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Only the has_pad_y = false variant of ClGemmMatrixMultiplyReshapedOnlyRhsKernel is configured here.
    // At run time we check the padding requirement of the lhs and dst tensors and only dispatch
    // the kernel when they carry no cross-plane (y) padding.

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                              const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info         = gemm_config.lhs_info;
    rhs_info         = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    // Reshape Rhs matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);

    // Check if we need to reshape the matrix B only on the first run
    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
    _is_prepared                 = gemm_info.retain_internal_weights();

    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
                                                b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    // Compute the GEMM dimensions
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Check data type early because the auto_select_gemm_kernel has assertions on supported data types
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);

    // Select GEMMType
    CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
    {
        CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
    },
    gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        default:
        {
            ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
        }
    }

    return Status{};
}

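// run() re-acquires the auxiliary reshape buffers from the tensor pack through
// CLAuxTensorHandler and enqueues the kernel chain selected at configure time.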
void ClGemm::run(ITensorPack &tensors)
{
    const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
    const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
    ITensor       *dst = tensors.get_tensor(ACL_DST);

    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst);

    CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
    CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);

    // Prepare the consts if needed
    prepare(tensors);

    // Run matrix multiply kernel
    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            // Run interleave kernel
            ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);

            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
            ITensorPack gemm_reshaped_pack(tensors);
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
            // Check if the lhs or dst tensors have padding
            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
            bool               has_pad_y           = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);

            // Copy original tensor pack and overwrite rhs with reshaped counterpart
            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

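            // Only the has_pad_y = false kernel variant was configured, so any
            // cross-plane (y) padding on the lhs or dst tensors is an error here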
            if(has_pad_y)
            {
                ARM_COMPUTE_ERROR_ON(has_pad_y);
            }
            else
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
            }
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // In case of RESHAPED_ONLY_RHS_MMUL, we need to check the padding requirement
            // Check if the lhs or dst tensors have padding
            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
            bool               has_pad_y           = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);

            // Copy original tensor pack and overwrite rhs with reshaped counterpart
            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            // Only the has_pad_y = false kernel variant was configured (see above)
            if(has_pad_y)
            {
                ARM_COMPUTE_ERROR_ON(has_pad_y);
            }
            else
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

void ClGemm::prepare(ITensorPack &constants)
{
    if(!_is_prepared)
    {
        const ITensor *src1    = constants.get_const_tensor(ACL_SRC_1);
        ICLTensor     *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));

        // If memory for RHS is persistent and src1 is provided, re-transform; otherwise assume RHS is already transformed
        if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr))
        {
            ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");

            CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
            ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);

            ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
        }
        _is_prepared = true;
    }
}

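// Exposes the auxiliary buffers (LHS/RHS reshape workspaces) requested during
// configuration. Callers allocate backing memory for each MemoryInfo entry and
// bind it in the ITensorPack under the entry's slot id before run() (see the
// CLAuxTensorHandler usage in run() above).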
experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
} // namespace opencl
} // namespace arm_compute