• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H
25 #define SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H
26 
27 #include "src/runtime/CL/mlgo/Common.h"
28 #include "src/runtime/CL/mlgo/HeuristicTree.h"
29 
30 #include <iostream>
31 #include <map>
32 #include <string>
33 #include <utility>
34 namespace arm_compute
35 {
36 namespace mlgo
37 {
38 /** Query interface */
39 struct Query
40 {
41     std::string  ip_target; /**< The name of the IP target */
42     DataType     data_type; /**< Data type */
43     unsigned int m;         /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
44     unsigned int n;         /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
45     unsigned int k;         /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
46     unsigned int b;         /**< Batch size */
47 };
48 
49 bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs);
50 bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs);
51 bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs);
52 
53 /** MLGOHeuristics for configuring GEMM kernels */
54 class MLGOHeuristics
55 {
56 public:
57     /** Constructor */
58     MLGOHeuristics();
59     /** Default Destructor */
60     ~MLGOHeuristics() = default;
61     /** Prevent Copy Construct */
62     MLGOHeuristics(const MLGOHeuristics &) = delete;
63     /** Prevent Copy Assignment */
64     MLGOHeuristics &operator=(const MLGOHeuristics &) = delete;
65     /** Default Move Constructor */
66     MLGOHeuristics(MLGOHeuristics &&) = default;
67     /** Default Move Assignment */
68     MLGOHeuristics &operator=(MLGOHeuristics &&) = default;
69     /** Query the gemm type
70      *
71      * @param[in] query Query
72      *
73      * @return std::pair<bool, GEMMType>  signals if the query succeeded or failed
74      */
75     std::pair<bool, GEMMType> query_gemm_type(const Query &query) const;
76     /** Query the gemm configuration for native kernel
77      *
78      * @param[in] query Query
79      *
80      * @return std::pair<bool, GEMMConfigNative>   bool signals if the query succeeded or failed
81      */
82     std::pair<bool, GEMMConfigNative> query_gemm_config_native(const Query &query) const;
83     /** Query the gemm configuration for reshaped only rhs kernel
84      *
85      * @param[in] query Query
86      *
87      * @return std::pair<bool, GEMMConfigReshapedOnlyRHS>   bool signals if the query succeeded or failed
88      */
89     std::pair<bool, GEMMConfigReshapedOnlyRHS> query_gemm_config_reshaped_only_rhs(const Query &query) const;
90     /** Query the gemm configuration for reshaped kernel
91      *
92      * @param[in] query Query
93      *
94      * @return std::pair<bool, GEMMConfigReshaped>   bool signals if the query succeeded or failed
95      */
96     std::pair<bool, GEMMConfigReshaped> query_gemm_config_reshaped(const Query &query) const;
97     /** (Re)Load the heuristics from reading a dotmlgo file
98      *
99      * @param[in] filename Path to the dotmlgo file
100      *
101      * @return bool Signals if the reload succeeded or failed
102      */
103     bool reload_from_file(const std::string &filename);
104     /** (Re)Load the heuristics from reading an input stream
105      *
106      * @param[in] istream Istream containing mlgo heuristics
107      *
108      * @return bool Signals if the reload succeeded or failed
109      */
110     bool reload_from_stream(std::istream &istream);
111 
112     /** Get the heuristic tree from tree id
113      *
114      * @param[in] id Tree id.
115      *
116      * @return HeuristicTree&
117      */
118     std::pair<bool, HeuristicTree *> get_heuristic_tree(HeuristicTree::TreeID id);
119     /** Add a heuristic tree
120      * @param t Heuristic tree to be added
121      */
122     bool add_heuristic_tree(HeuristicTree &&t);
123 
124     /** Check the validity of the heuristic tree.
125      *
126      * @param id ID of the tree to be checked
127      *
128      * @return bool
129      */
130     bool check_heuristic_tree(HeuristicTree::TreeID id);
131 
132     /** Check the overall validity of the heuristics.
133      * @return bool
134      */
135     bool check_all() const;
136 
137 private:
138     static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/
139 
140 private:
141     // There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree
142     std::map<HeuristicTree::TreeID, HeuristicTree::Index> _indices;    /**< A mapping from TreeID to Index */
143     std::map<HeuristicTree::Index, HeuristicTree>         _trees;      /**< A mapping from Index to HeuristicTree */
144     std::map<HeuristicTree::TreeID, bool>                 _tree_valid; /**< Result cache of the tree validity checks */
145     bool _valid;                                                       /**< Overall validity */
146 };
147 
148 } // namespace mlgo
149 } // namespace arm_compute
150 #endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H