1 /* 2 * Copyright (c) 2021 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H 25 #define SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H 26 27 #include "arm_compute/core/GPUTarget.h" 28 #include "arm_compute/core/Types.h" 29 #include "arm_compute/runtime/CL/CLTypes.h" 30 31 namespace arm_compute 32 { 33 namespace cl_gemm 34 { 35 namespace auto_heuristics 36 { 37 /** A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and default heuristics */ 38 39 /** Common query */ 40 struct CommonQuery 41 { 42 GPUTarget gpu_target; /**< Which @ref GPUTarget to query about */ 43 DataType data_type; /**< Data type */ 44 unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ 45 unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ 46 unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ 47 unsigned int b; /**< Batch size */ 48 }; 49 50 /** Result of querying about GEMM type ( @ref CLGEMMKernelType) */ 51 struct GEMMTypeResult 52 { GEMMTypeResultGEMMTypeResult53 GEMMTypeResult(bool valid, CLGEMMKernelType gemm_type) 54 : valid{ valid }, gemm_type{ gemm_type } 55 { 56 } 57 /** Test if the result is valid */ 58 operator bool() const 59 { 60 return valid; 61 } 62 bool valid; /** If the result is valid */ 63 CLGEMMKernelType gemm_type; /** @ref CLGEMMKernelType */ 64 }; 65 66 /** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */ 67 struct GEMMConfigResult 68 { GEMMConfigResultGEMMConfigResult69 GEMMConfigResult(bool valid, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info) 70 : valid{ valid }, lhs_info{ lhs_info }, rhs_info{ rhs_info } 71 { 72 } 73 /** Test if the result is valid */ 74 operator bool() const 75 { 76 return valid; 77 } 78 bool valid; /** If the result is valid */ 79 GEMMLHSMatrixInfo lhs_info; /** @ref GEMMLHSMatrixInfo */ 80 GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */ 81 }; 82 83 /** Select gemm type based on mlgo heuristics 84 * @param query Query 85 * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run 86 * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise 87 */ 88 GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); 89 90 /** Select gemm type based on default heuristics 91 * @param query Query 92 * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run 93 * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise 94 */ 95 GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); 96 97 /** Select gemm config based on mlgo heuristics 98 * @param query Query 99 * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise 100 */ 101 GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query); 102 103 /** Select gemm config based on default heuristics 104 * @param query Query 105 * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise 106 */ 107 GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); 108 109 /** Select gemm config based on mlgo heuristics 110 * @param query Query 111 * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise 112 */ 113 GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query); 114 115 /** Select gemm config based on default heuristics 116 * @param query Query 117 * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise 118 */ 119 GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query); 120 121 /** Select gemm config based on mlgo heuristics 122 * @param query Query 123 * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise 124 */ 125 GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query); 126 127 /** Select gemm config based on default heuristics 128 * @param query Query 129 * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise 130 */ 131 GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query); 132 133 } // namespace auto_heuristics 134 } // namespace cl_gemm 135 } // namespace arm_compute 136 137 #endif // SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H