1 /* 2 * Copyright (c) 2021 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H 25 #define SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H 26 27 #include "src/runtime/CL/mlgo/Common.h" 28 #include "src/runtime/CL/mlgo/HeuristicTree.h" 29 30 #include <iostream> 31 #include <map> 32 #include <string> 33 #include <utility> 34 namespace arm_compute 35 { 36 namespace mlgo 37 { 38 /** Query interface */ 39 struct Query 40 { 41 std::string ip_target; /**< The name of the IP target */ 42 DataType data_type; /**< Data type */ 43 unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ 44 unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ 45 unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ 46 unsigned int b; /**< Batch size */ 47 }; 48 49 bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs); 50 bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs); 51 bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs); 52 53 /** MLGOHeuristics for configuring GEMM kernels */ 54 class MLGOHeuristics 55 { 56 public: 57 /** Constructor */ 58 MLGOHeuristics(); 59 /** Default Destructor */ 60 ~MLGOHeuristics() = default; 61 /** Prevent Copy Construct */ 62 MLGOHeuristics(const MLGOHeuristics &) = delete; 63 /** Prevent Copy Assignment */ 64 MLGOHeuristics &operator=(const MLGOHeuristics &) = delete; 65 /** Default Move Constructor */ 66 MLGOHeuristics(MLGOHeuristics &&) = default; 67 /** Default Move Assignment */ 68 MLGOHeuristics &operator=(MLGOHeuristics &&) = default; 69 /** Query the gemm type 70 * 71 * @param[in] query Query 72 * 73 * @return std::pair<bool, GEMMType> signals if the query succeeded or failed 74 */ 75 std::pair<bool, GEMMType> query_gemm_type(const Query &query) const; 76 /** Query the gemm configuration for native kernel 77 * 78 * @param[in] query Query 79 * 80 * @return std::pair<bool, GEMMConfigNative> bool signals if the query succeeded or failed 81 */ 82 std::pair<bool, GEMMConfigNative> query_gemm_config_native(const Query &query) const; 83 /** Query the gemm configuration for reshaped only rhs kernel 84 * 85 * @param[in] query Query 86 * 87 * @return std::pair<bool, GEMMConfigReshapedOnlyRHS> bool signals if the query succeeded or failed 88 */ 89 std::pair<bool, GEMMConfigReshapedOnlyRHS> query_gemm_config_reshaped_only_rhs(const Query &query) const; 90 /** Query the gemm configuration for reshaped kernel 91 * 92 * @param[in] query Query 93 * 94 * @return std::pair<bool, GEMMConfigReshaped> bool signals if the query succeeded or failed 95 */ 96 std::pair<bool, GEMMConfigReshaped> query_gemm_config_reshaped(const Query &query) const; 97 /** (Re)Load the heuristics from reading a dotmlgo file 98 * 99 * @param[in] filename Path to the dotmlgo file 100 * 101 * @return bool Signals if the reload succeeded or failed 102 */ 103 bool reload_from_file(const std::string &filename); 104 /** (Re)Load the heuristics from reading an input stream 105 * 106 * @param[in] istream Istream containing mlgo heuristics 107 * 108 * @return bool Signals if the reload succeeded or failed 109 */ 110 bool reload_from_stream(std::istream &istream); 111 112 /** Get the heuristic tree from tree id 113 * 114 * @param[in] id Tree id. 115 * 116 * @return HeuristicTree& 117 */ 118 std::pair<bool, HeuristicTree *> get_heuristic_tree(HeuristicTree::TreeID id); 119 /** Add a heuristic tree 120 * @param t Heuristic tree to be added 121 */ 122 bool add_heuristic_tree(HeuristicTree &&t); 123 124 /** Check the validity of the heuristic tree. 125 * 126 * @param id ID of the tree to be checked 127 * 128 * @return bool 129 */ 130 bool check_heuristic_tree(HeuristicTree::TreeID id); 131 132 /** Check the overall validity of the heuristics. 133 * @return bool 134 */ 135 bool check_all() const; 136 137 private: 138 static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/ 139 140 private: 141 // There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree 142 std::map<HeuristicTree::TreeID, HeuristicTree::Index> _indices; /**< A mapping from TreeID to Index */ 143 std::map<HeuristicTree::Index, HeuristicTree> _trees; /**< A mapping from Index to HeuristicTree */ 144 std::map<HeuristicTree::TreeID, bool> _tree_valid; /**< Result cache of the tree validity checks */ 145 bool _valid; /**< Overall validity */ 146 }; 147 148 } // namespace mlgo 149 } // namespace arm_compute 150 #endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H