1 /*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/runtime/CL/mlgo/MLGOHeuristics.h"
25
26 #include "arm_compute/core/Log.h"
27 #include "src/runtime/CL/mlgo/MLGOParser.h"
28 #include "src/runtime/CL/mlgo/Utils.h"
29
30 #include <fstream>
31
32 namespace arm_compute
33 {
34 namespace mlgo
35 {
operator ==(const GEMMConfigNative & lhs,const GEMMConfigNative & rhs)36 bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
37 {
38 return std::tie(lhs.m0, lhs.n0, lhs.k0) == std::tie(rhs.m0, rhs.n0, rhs.k0);
39 }
operator ==(const GEMMConfigReshapedOnlyRHS & lhs,const GEMMConfigReshapedOnlyRHS & rhs)40 bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs)
41 {
42 return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs,
43 rhs.export_cl_image);
44 }
operator ==(const GEMMConfigReshaped & lhs,const GEMMConfigReshaped & rhs)45 bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
46 {
47 return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0,
48 rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
49 }
50
// Namespace-scope definition of the constexpr member _max_num_trees. No initializer appears here, so the
// value is presumably supplied at the in-class declaration (not visible in this file).
51 constexpr size_t MLGOHeuristics::_max_num_trees;
52
MLGOHeuristics()53 MLGOHeuristics::MLGOHeuristics()
54 : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
55 {
56 }
57
query_gemm_type(const Query & query) const58 std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) const
59 {
60 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
61 const auto invalid = GEMMType::RESHAPED;
62 if(!_valid)
63 {
64 ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
65 return { false, invalid };
66 }
67 auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
68 GEMMShape shape_query{ query.m, query.n, query.k, query.b };
69 if(_trees.find(index) == _trees.end())
70 {
71 ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
72 return { false, invalid };
73 }
74 return _trees.at(index).query<GEMMType>(shape_query);
75 }
query_gemm_config_native(const Query & query) const76 std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
77 {
78 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
79 const auto invalid = GEMMConfigNative{};
80 if(!_valid)
81 {
82 ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
83 return { false, invalid };
84 }
85 auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
86 GEMMShape shape_query{ query.m, query.n, query.k, query.b };
87 if(_trees.find(index) == _trees.end())
88 {
89 ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
90 return { false, invalid };
91 }
92 return _trees.at(index).query<GEMMConfigNative>(shape_query);
93 }
query_gemm_config_reshaped_only_rhs(const Query & query) const94 std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
95 {
96 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
97 const auto invalid = GEMMConfigReshapedOnlyRHS{};
98 if(!_valid)
99 {
100 ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
101 return { false, invalid };
102 }
103 auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
104 GEMMShape shape_query{ query.m, query.n, query.k, query.b };
105 if(_trees.find(index) == _trees.end())
106 {
107 ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
108 return { false, invalid };
109 }
110 return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
111 }
query_gemm_config_reshaped(const Query & query) const112 std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
113 {
114 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
115 const auto invalid = GEMMConfigReshaped{};
116 if(!_valid)
117 {
118 ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
119 return { false, invalid };
120 }
121 auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
122 GEMMShape shape_query{ query.m, query.n, query.k, query.b };
123 if(_trees.find(index) == _trees.end())
124 {
125 ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
126 return { false, invalid };
127 }
128 return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
129 }
130
check_heuristic_tree(HeuristicTree::TreeID id)131 bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
132 {
133 bool status;
134 HeuristicTree *tree{ nullptr };
135 std::tie(status, tree) = get_heuristic_tree(id);
136 if(!status)
137 {
138 return status;
139 }
140 status = tree->check();
141 if(!status)
142 {
143 return status;
144 }
145 _tree_valid[id] = true;
146 return true;
147 }
148
check_all() const149 bool MLGOHeuristics::check_all() const
150 {
151 // Tree validities are already checked and cached.
152 bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v)
153 {
154 return !v.second;
155 })
156 == _tree_valid.end();
157 if(!all_trees_are_checked)
158 {
159 ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
160 return false;
161 }
162
163 // Other top level checks...
164
165 return true;
166 }
167
get_heuristic_tree(HeuristicTree::TreeID id)168 std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
169 {
170 if(_indices.find(id) == _indices.end())
171 {
172 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
173 return std::make_pair(false, nullptr);
174 }
175 const auto index = _indices[id];
176
177 if(_trees.find(index) == _trees.end())
178 {
179 ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
180 return std::make_pair(false, nullptr);
181 }
182 auto &t = _trees[index];
183
184 return std::make_pair(true, &t);
185 }
186
add_heuristic_tree(HeuristicTree && t)187 bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
188 {
189 if(_indices.size() >= _max_num_trees)
190 {
191 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
192 return false;
193 }
194 // PRE: correctness of t is guaranteed by the tree construction process
195 // Ensure unique id
196 const auto id = t.id();
197 if(_indices.find(id) != _indices.end())
198 {
199 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id);
200 return false;
201 }
202
203 // Ensure unique index
204 const auto index = t.index();
205 if(_trees.find(index) != _trees.end())
206 {
207 ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
208 return false;
209 }
210
211 _indices[id] = index;
212 _trees[index] = std::move(t);
213 _tree_valid[id] = false;
214 return true;
215 }
216
reload_from_file(const std::string & filename)217 bool MLGOHeuristics::reload_from_file(const std::string &filename)
218 {
219 std::ifstream fs;
220 fs.exceptions(std::ifstream::badbit);
221 fs.open(filename, std::ios::in);
222 if(!fs.is_open())
223 {
224 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
225 return _valid = false;
226 }
227 return reload_from_stream(fs);
228 }
229
reload_from_stream(std::istream & in)230 bool MLGOHeuristics::reload_from_stream(std::istream &in)
231 {
232 auto parsed = parser::parse_mlgo(in);
233 if(!parsed.first)
234 {
235 ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
236 return _valid = false;
237 }
238 *this = std::move(parsed.second);
239 ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
240 return _valid = true;
241 }
242
243 } // namespace mlgo
244 } // namespace arm_compute