/*
 * Copyright (c) 2016-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23 */ 24 #ifndef ARM_COMPUTE_CL_GEMM_H 25 #define ARM_COMPUTE_CL_GEMM_H 26 27 #include "arm_compute/core/TensorInfo.h" 28 #include "arm_compute/runtime/CL/CLTensor.h" 29 #include "arm_compute/runtime/CL/CLTypes.h" 30 31 #include "src/gpu/cl/ClCompileContext.h" 32 #include "src/gpu/cl/IClKernel.h" 33 #include "src/gpu/cl/IClOperator.h" 34 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h" 35 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h" 36 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h" 37 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" 38 #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" 39 #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" 40 41 #include <memory> 42 43 namespace arm_compute 44 { 45 namespace opencl 46 { 47 /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels: 48 * 49 * -# @ref kernels::ClGemmReshapeLhsMatrixKernel (only if the RESHAPED is selected by the heuristic model) 50 * -# @ref kernels::ClGemmReshapeRhsMatrixKernel (only if either the RESHAPED or RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method()) 51 * -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel method()) 52 * -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel method()) 53 * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method()) 54 * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel method()) 55 */ 56 class ClGemm : public IClOperator 57 { 58 public: 59 /** Constructor */ 60 ClGemm(); 61 /** Initialise the kernel's inputs and output 62 * 63 * Valid data layouts: 64 * - All 65 * 66 * Valid data type configurations: 67 * |src0 |src1 |src2 |dst | 68 * 
|:------------|:-----------|:---------|:--------------| 69 * |F32 |F32 |F32 |F32 | 70 * |F16 |F16 |F16 |F16 | 71 * 72 * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C]. 73 * 74 * @note All tensors must have the same data type. 75 * 76 * @note Whilst the first input tensor can be a vector, the second input tensor must be at least a matrix 77 * 78 * @note Batched GEMM only allows RHS tensor's rank to be <= 3 79 * @note Batched GEMM only supports broadcasting cases where RHS rank < LHS rank but not the other way around 80 * 81 * @param[in] compile_context The compile context to be used. 82 * @param[in] a First input tensor (Matrix or Vector A). Data types supported: F16/F32 83 * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a. 84 * @param[in] c Third input tensor (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a. 85 * @param[out] output Output tensor. Data type supported: same as @p a 86 * @param[in] alpha Weight of the matrix product 87 * @param[in] beta Weight of matrix C 88 * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and 89 * if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping 90 * in case matrix A and matrix B have been already transformed. 
91 */ 92 void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 93 /** Static function to check if given info will lead to a valid configuration 94 * 95 * Similar to ClGemm::configure() 96 * 97 * @return a status 98 */ 99 static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 100 101 // Inherited methods overridden: 102 void run(ITensorPack &tensors) override; 103 void prepare(ITensorPack &constants) override; 104 experimental::MemoryRequirements workspace() const override; 105 106 private: 107 void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 108 void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 109 void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 110 void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 111 112 static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 113 static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 114 static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, 
const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 115 static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); 116 117 private: 118 enum AuxTensorIdx 119 { 120 LhsReshape = 0, 121 RhsReshape, 122 Count 123 }; 124 125 private: 126 std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel> _reshape_lhs_kernel; 127 std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _reshape_rhs_kernel; 128 std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel> _mm_native_kernel; 129 std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel> _mm_reshaped_kernel; 130 std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel; 131 std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel; 132 TensorInfo _tmp_a; 133 TensorInfo _tmp_b; 134 bool _reshape_b_only_on_first_run; 135 CLGEMMKernelType _gemm_kernel_type; 136 bool _is_prepared; 137 experimental::MemoryRequirements _aux_mem{}; 138 }; 139 } // namespace opencl 140 } // namespace arm_compute 141 #endif /* ARM_COMPUTE_CLGEMM_H */ 142