/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_CLGEMM_H
#define ARM_COMPUTE_CLGEMM_H

#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTypes.h"
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"

#include <memory>

namespace arm_compute
{
class CLCompileContext;
class CLGEMMReshapeRHSMatrixKernel;
class CLGEMMMatrixMultiplyKernel;
class CLGEMMMatrixMultiplyReshapedKernel;
class CLGEMMMatrixMultiplyReshapedOnlyRHSKernel;
class CLGEMMReshapeLHSMatrixKernel;
class ICLTensor;
class ITensorInfo;

namespace weights_transformations
{
/** Basic function to manage the reshaped weights generated from @ref CLGEMMReshapeRHSMatrixKernel */
class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights
{
public:
    /** Default constructor */
    CLGEMMReshapeRHSMatrixKernelManaged();
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    CLGEMMReshapeRHSMatrixKernelManaged(const CLGEMMReshapeRHSMatrixKernelManaged &) = delete;
    /** Default move constructor */
    CLGEMMReshapeRHSMatrixKernelManaged(CLGEMMReshapeRHSMatrixKernelManaged &&) = default;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    CLGEMMReshapeRHSMatrixKernelManaged &operator=(const CLGEMMReshapeRHSMatrixKernelManaged &) = delete;
    /** Default move assignment operator */
    CLGEMMReshapeRHSMatrixKernelManaged &operator=(CLGEMMReshapeRHSMatrixKernelManaged &&) = default;
    /** Default destructor */
    ~CLGEMMReshapeRHSMatrixKernelManaged();
    // Inherited method override
    void run() override;

    // Inherited method override
    void release() override;

    // Inherited method override
    ICLTensor *get_weights() override;

    // Inherited method override
    uint32_t uid() override;

    /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
     *
     * @param[in] input Input tensor. Data types supported: All
     * @param[in] info  RHS matrix information to be used for reshaping.
     */
    void configure(const ICLTensor *input, GEMMRHSMatrixInfo info);
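
    // A minimal sketch of how this managed transform could be driven (hypothetical
    // names; within the library CLGEMM normally acquires it through the IWeightsManager
    // rather than calling it directly):
    //
    //   CLGEMMReshapeRHSMatrixKernelManaged managed_rhs;
    //   managed_rhs.configure(&weights, rhs_info); // rhs_info is a GEMMRHSMatrixInfo
    //   managed_rhs.run();                         // reshaped weights are then available via get_weights()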

    /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
     *
     * @param[in] compile_context The compile context to be used.
     * @param[in] input           Input tensor. Data types supported: All
     * @param[in] info            RHS matrix information to be used for reshaping.
     */
    void configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info);

private:
    static constexpr uint32_t                     _uid{ 0x15 };
    CLTensor                                      _output{};
    std::unique_ptr<CLGEMMReshapeRHSMatrixKernel> _kernel;
};
} // namespace weights_transformations

/** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
 *
 *  -# @ref CLGEMMReshapeLHSMatrixKernel (only if RESHAPED_V1 is selected by the heuristic model)
 *  -# @ref CLGEMMReshapeRHSMatrixKernel (only if either RESHAPED_V1 or RESHAPED_ONLY_RHS is selected by the select_gemm_kernel() method)
 *  -# @ref CLGEMMMatrixMultiplyKernel (only if either NATIVE or RESHAPED_V1 is selected by the select_gemm_kernel() method)
 *  -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if RESHAPED_V1 is selected by the select_gemm_kernel() method)
 *  -# @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel() method)
 *
 */
class CLGEMM : public IFunction
{
public:
    /** Default constructor.
     *
     * @param[in] memory_manager  (Optional) Memory manager.
     * @param[in] weights_manager (Optional) Weights manager.
     */
    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    CLGEMM(const CLGEMM &) = delete;
    /** Default move constructor */
    CLGEMM(CLGEMM &&) = default;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    CLGEMM &operator=(const CLGEMM &) = delete;
    /** Default move assignment operator */
    CLGEMM &operator=(CLGEMM &&) = default;
    /** Default destructor */
    ~CLGEMM();
    /** Initialise the kernel's inputs and output
     *
     * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
     *
     * @note All tensors must have the same data type.
     *
     * @note Whilst the first input tensor can be a vector, the second input tensor must be at least a matrix
     *
     * @param[in]  a         First input tensor (Matrix or Vector A). Data types supported: F16/F32
     * @param[in]  b         Second input tensor (Matrix B). Data type supported: same as @p a.
     * @param[in]  c         Third input tensor (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
     * @param[out] output    Output tensor. Data type supported: same as @p a
     * @param[in]  alpha     Weight of the matrix product
     * @param[in]  beta      Weight of matrix C
     * @param[in]  gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
     *                       if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
     *                       in case matrix A and matrix B have been already transformed.
     */
    void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
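    // Sketch of a typical call to the overload above (hypothetical tensors a, b and dst,
    // assumed to be initialised and allocated elsewhere; beta has no effect here since c is nullptr):
    //
    //   CLGEMM gemm;
    //   gemm.configure(&a, &b, nullptr, &dst, 1.0f, 0.0f);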
    /** Initialise the kernel's inputs and output
     *
     * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
     *
     * @note All tensors must have the same data type.
     *
     * @note Whilst the first input tensor can be a vector, the second input tensor must be at least a matrix
     *
     * @param[in]  compile_context The compile context to be used.
     * @param[in]  a               First input tensor (Matrix or Vector A). Data types supported: F16/F32
     * @param[in]  b               Second input tensor (Matrix B). Data type supported: same as @p a.
     * @param[in]  c               Third input tensor (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
     * @param[out] output          Output tensor. Data type supported: same as @p a
     * @param[in]  alpha           Weight of the matrix product
     * @param[in]  beta            Weight of matrix C
     * @param[in]  gemm_info       (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
     *                             if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
     *                             in case matrix A and matrix B have been already transformed.
     */
    void configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
    /** Static function to check if given info will lead to a valid configuration of @ref CLGEMM.
     *
     * @param[in] a         First input tensor info (Matrix or Vector A). Data types supported: F16/F32
     * @param[in] b         Second input tensor info (Matrix B). Data type supported: same as @p a.
     * @param[in] c         Third input tensor info (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
     * @param[in] output    Output tensor info. Data type supported: same as @p a
     * @param[in] alpha     Weight of the matrix product
     * @param[in] beta      Weight of matrix C
     * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
     *                      if the reshape of matrix B should happen only for the first run
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
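    // A minimal end-to-end sketch (hypothetical tensors; assumes the shapes and data types
    // satisfy the requirements documented above):
    //
    //   if(bool(CLGEMM::validate(a.info(), b.info(), nullptr, dst.info(), 1.0f, 0.0f)))
    //   {
    //       gemm.configure(&a, &b, nullptr, &dst, 1.0f, 0.0f);
    //       gemm.run(); // prepare() is typically handled internally on the first run
    //   }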

    // Inherited methods overridden:
    void run() override;
    void prepare() override;

private:
    static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run);

    void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
    void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
    void configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
    void configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
                                     const GEMMInfo &gemm_info);

    static Status validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    static Status validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);

    MemoryGroup                                                                   _memory_group;
    IWeightsManager                                                              *_weights_manager;
    std::unique_ptr<CLGEMMMatrixMultiplyKernel>                                   _mm_kernel;
    std::unique_ptr<CLGEMMReshapeLHSMatrixKernel>                                 _reshape_lhs_kernel;
    std::unique_ptr<CLGEMMReshapeRHSMatrixKernel>                                 _reshape_rhs_kernel;
    std::unique_ptr<weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged> _reshape_rhs_kernel_managed;
    std::unique_ptr<CLGEMMMatrixMultiplyReshapedKernel>                           _mm_reshaped_kernel;
    std::unique_ptr<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>                    _mm_reshaped_only_rhs_kernel;
    std::unique_ptr<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>                    _mm_reshaped_only_rhs_fallback_kernel;
    CLTensor                                                                      _tmp_a;
    CLTensor                                                                      _tmp_b;
    const ICLTensor                                                              *_original_b;
    const ICLTensor                                                              *_lhs;
    ICLTensor                                                                    *_dst;
    bool                                                                          _reshape_b_only_on_first_run;
    bool                                                                          _is_prepared;
    CLGEMMKernelType                                                              _gemm_kernel_type;
};
} // namespace arm_compute

#endif /* ARM_COMPUTE_CLGEMM_H */