• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/runtime/CL/mlgo/MLGOHeuristics.h"
25 
26 #include "arm_compute/core/Log.h"
27 #include "src/runtime/CL/mlgo/MLGOParser.h"
28 #include "src/runtime/CL/mlgo/Utils.h"
29 
30 #include <fstream>
31 
32 namespace arm_compute
33 {
34 namespace mlgo
35 {
operator ==(const GEMMConfigNative & lhs,const GEMMConfigNative & rhs)36 bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
37 {
38     return std::tie(lhs.m0, lhs.n0, lhs.k0) == std::tie(rhs.m0, rhs.n0, rhs.k0);
39 }
operator ==(const GEMMConfigReshapedOnlyRHS & lhs,const GEMMConfigReshapedOnlyRHS & rhs)40 bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs)
41 {
42     return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs,
43                                                                                                                             rhs.export_cl_image);
44 }
operator ==(const GEMMConfigReshaped & lhs,const GEMMConfigReshaped & rhs)45 bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
46 {
47     return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0,
48             rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
49 }
50 
51 constexpr size_t MLGOHeuristics::_max_num_trees;
52 
MLGOHeuristics()53 MLGOHeuristics::MLGOHeuristics()
54     : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
55 {
56 }
57 
query_gemm_type(const Query & query) const58 std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) const
59 {
60     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
61     const auto invalid = GEMMType::RESHAPED;
62     if(!_valid)
63     {
64         ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
65         return { false, invalid };
66     }
67     auto      index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
68     GEMMShape shape_query{ query.m, query.n, query.k, query.b };
69     if(_trees.find(index) == _trees.end())
70     {
71         ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
72         return { false, invalid };
73     }
74     return _trees.at(index).query<GEMMType>(shape_query);
75 }
query_gemm_config_native(const Query & query) const76 std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
77 {
78     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
79     const auto invalid = GEMMConfigNative{};
80     if(!_valid)
81     {
82         ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
83         return { false, invalid };
84     }
85     auto      index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
86     GEMMShape shape_query{ query.m, query.n, query.k, query.b };
87     if(_trees.find(index) == _trees.end())
88     {
89         ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
90         return { false, invalid };
91     }
92     return _trees.at(index).query<GEMMConfigNative>(shape_query);
93 }
query_gemm_config_reshaped_only_rhs(const Query & query) const94 std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
95 {
96     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
97     const auto invalid = GEMMConfigReshapedOnlyRHS{};
98     if(!_valid)
99     {
100         ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
101         return { false, invalid };
102     }
103     auto      index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
104     GEMMShape shape_query{ query.m, query.n, query.k, query.b };
105     if(_trees.find(index) == _trees.end())
106     {
107         ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
108         return { false, invalid };
109     }
110     return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
111 }
query_gemm_config_reshaped(const Query & query) const112 std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
113 {
114     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
115     const auto invalid = GEMMConfigReshaped{};
116     if(!_valid)
117     {
118         ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
119         return { false, invalid };
120     }
121     auto      index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
122     GEMMShape shape_query{ query.m, query.n, query.k, query.b };
123     if(_trees.find(index) == _trees.end())
124     {
125         ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
126         return { false, invalid };
127     }
128     return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
129 }
130 
check_heuristic_tree(HeuristicTree::TreeID id)131 bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
132 {
133     bool           status;
134     HeuristicTree *tree{ nullptr };
135     std::tie(status, tree) = get_heuristic_tree(id);
136     if(!status)
137     {
138         return status;
139     }
140     status = tree->check();
141     if(!status)
142     {
143         return status;
144     }
145     _tree_valid[id] = true;
146     return true;
147 }
148 
check_all() const149 bool MLGOHeuristics::check_all() const
150 {
151     // Tree validities are already checked and cached.
152     bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v)
153     {
154         return !v.second;
155     })
156     == _tree_valid.end();
157     if(!all_trees_are_checked)
158     {
159         ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
160         return false;
161     }
162 
163     // Other top level checks...
164 
165     return true;
166 }
167 
get_heuristic_tree(HeuristicTree::TreeID id)168 std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
169 {
170     if(_indices.find(id) == _indices.end())
171     {
172         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
173         return std::make_pair(false, nullptr);
174     }
175     const auto index = _indices[id];
176 
177     if(_trees.find(index) == _trees.end())
178     {
179         ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
180         return std::make_pair(false, nullptr);
181     }
182     auto &t = _trees[index];
183 
184     return std::make_pair(true, &t);
185 }
186 
add_heuristic_tree(HeuristicTree && t)187 bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
188 {
189     if(_indices.size() >= _max_num_trees)
190     {
191         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
192         return false;
193     }
194     // PRE: correctness of t is guaranteed by the tree construction process
195     // Ensure unique id
196     const auto id = t.id();
197     if(_indices.find(id) != _indices.end())
198     {
199         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id);
200         return false;
201     }
202 
203     // Ensure unique index
204     const auto index = t.index();
205     if(_trees.find(index) != _trees.end())
206     {
207         ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
208         return false;
209     }
210 
211     _indices[id]    = index;
212     _trees[index]   = std::move(t);
213     _tree_valid[id] = false;
214     return true;
215 }
216 
reload_from_file(const std::string & filename)217 bool MLGOHeuristics::reload_from_file(const std::string &filename)
218 {
219     std::ifstream fs;
220     fs.exceptions(std::ifstream::badbit);
221     fs.open(filename, std::ios::in);
222     if(!fs.is_open())
223     {
224         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
225         return _valid = false;
226     }
227     return reload_from_stream(fs);
228 }
229 
reload_from_stream(std::istream & in)230 bool MLGOHeuristics::reload_from_stream(std::istream &in)
231 {
232     auto parsed = parser::parse_mlgo(in);
233     if(!parsed.first)
234     {
235         ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
236         return _valid = false;
237     }
238     *this = std::move(parsed.second);
239     ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
240     return _valid = true;
241 }
242 
243 } // namespace mlgo
244 } // namespace arm_compute