1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "extendrt/delegate/ascend_ge/aoe_api_tune_process.h"
18 #include <cstdio>
19 #include <iostream>
20 #include <tuple>
21 #include <vector>
22 #include <string>
23 #include <map>
24 #include "mindspore/lite/src/common/common.h"
25 #include "mindspore/lite/src/extendrt/cxx_api/dlutils.h"
26 #include "mindspore/ccsrc/utils/dlopen_macro.h"
27 #include "mindspore/ccsrc/cxx_api/acl_utils.h"
28
29 namespace mindspore {
namespace {
// Human-readable tuning mode names, matched as substrings of the "aoe_mode" config value.
// NOTE: the identifiers spell "Turning" (sic); they are kept because other code in this file uses them.
constexpr const char *kSubgraphTurning = "subgraph tuning";
constexpr const char *kOperatorTurning = "operator tuning";
// Numeric identifiers of the same modes, used as the AOE "job_type" option value.
constexpr const char *kSubgraphTurningIndex = "1";
constexpr const char *kOperatorTurningIndex = "2";

// Maps a job_type identifier to the mode name printed as progress output.
const std::map<std::string, std::string> kTuneModeMap = {
  {kSubgraphTurningIndex, kSubgraphTurning},
  {kOperatorTurningIndex, kOperatorTurning},
};
}  // namespace
38
// Integer status type returned by the AOE library entry points that AoePlugin resolves.
using AoeStatus = int32_t;
// Status value that the AoePlugin wrappers treat as a successful AOE call.
constexpr AoeStatus AOE_SUCCESS = 0;
41
42 class AoePlugin {
43 public:
Instance()44 static AoePlugin &Instance() {
45 static AoePlugin instance;
46 return instance;
47 }
48 AoePlugin() = default;
~AoePlugin()49 ~AoePlugin() { DLSoClose(handle_); }
LoadAoePlugin()50 bool LoadAoePlugin() {
51 if (handle_ != nullptr) {
52 return true;
53 }
54 std::string aoe_so_name = "libaoe_tuning.so";
55 std::string aoe_init_func_name = "AoeInitialize";
56 std::string aoe_finalize_func_name = "AoeFinalize";
57 std::string aoe_create_session_func_name = "AoeCreateSession";
58 std::string aoe_destroy_session_func_name = "AoeDestroySession";
59 std::string aoe_set_ge_session_func_name = "AoeSetGeSession";
60 std::string aoe_set_tuning_graph_func_name = "AoeSetTuningGraph";
61 std::string aoe_set_tuning_graph_input_func_name = "AoeSetTuningGraphInput";
62 std::string aoe_tuning_graph_func_name = "AoeTuningGraph";
63
64 auto status = DLSoOpen(aoe_so_name, "", &handle_, nullptr);
65 if (status != kSuccess) {
66 MS_LOG(ERROR) << "Dlopen " << aoe_so_name << " failed, result = " << status.ToString();
67 return false;
68 }
69 try {
70 aoe_initialize_func_ = DlsymWithCast<AoeInitializeFunc>(handle_, aoe_init_func_name.c_str());
71 aoe_finalize_func_ = DlsymWithCast<AoeFinalizeFunc>(handle_, aoe_finalize_func_name.c_str());
72 aoe_create_session_func_ = DlsymWithCast<AoeCreateSessionFunc>(handle_, aoe_create_session_func_name.c_str());
73 aoe_destroy_session_func_ = DlsymWithCast<AoeDestroySessionFunc>(handle_, aoe_destroy_session_func_name.c_str());
74 aoe_set_ge_session_func_ = DlsymWithCast<AoeSetGeSessionFunc>(handle_, aoe_set_ge_session_func_name.c_str());
75 aoe_set_tuning_graph_func_ =
76 DlsymWithCast<AoeSetTuningGraphFunc>(handle_, aoe_set_tuning_graph_func_name.c_str());
77 aoe_set_tuning_graph_input_func_ =
78 DlsymWithCast<AoeSetTuningGraphInputFunc>(handle_, aoe_set_tuning_graph_input_func_name.c_str());
79 aoe_tuning_graph_func_ = DlsymWithCast<AoeTuningGraphFunc>(handle_, aoe_tuning_graph_func_name.c_str());
80 } catch (const std::runtime_error &error) {
81 MS_LOG(ERROR) << "Failed to load symbol from " << aoe_so_name;
82 return false;
83 }
84 return true;
85 }
AoeInitialize(const std::map<std::string,std::string> & global_options)86 bool AoeInitialize(const std::map<std::string, std::string> &global_options) {
87 if (aoe_initialize_func_ == nullptr) {
88 MS_LOG(ERROR) << "aoe_initialize_func_ is nullptr";
89 return false;
90 }
91 std::map<ge::AscendString, ge::AscendString> options;
92 for (auto &item : global_options) {
93 MS_LOG(INFO) << "Aoe global option " << item.first << " = " << item.second;
94 options[ge::AscendString(item.first.c_str())] = ge::AscendString(item.second.c_str());
95 }
96 auto aoe_status = aoe_initialize_func_(options);
97 if (aoe_status != AOE_SUCCESS) {
98 MS_LOG(ERROR) << "Failed to call AoeInitialize, ret: " << aoe_status;
99 return false;
100 }
101 return true;
102 }
AoeFinalize()103 void AoeFinalize() {
104 if (aoe_finalize_func_ == nullptr) {
105 MS_LOG(ERROR) << "aoe_finalize_func_ is nullptr";
106 return;
107 }
108 aoe_finalize_func_();
109 }
AoeCreateSession(uint64_t * session_id)110 bool AoeCreateSession(uint64_t *session_id) {
111 if (session_id == nullptr) {
112 MS_LOG(ERROR) << "Input parameter session_id cannot be nullptr";
113 return false;
114 }
115 if (aoe_create_session_func_ == nullptr) {
116 MS_LOG(ERROR) << "aoe_create_session_func_ is nullptr";
117 return false;
118 }
119 auto aoe_status = aoe_create_session_func_(*session_id);
120 if (aoe_status != AOE_SUCCESS) {
121 MS_LOG(ERROR) << "Failed to call AoeCreateSession, ret: " << aoe_status;
122 return false;
123 }
124 return true;
125 }
AoeDestroySession(uint64_t session_id)126 void AoeDestroySession(uint64_t session_id) {
127 if (aoe_destroy_session_func_ == nullptr) {
128 MS_LOG(ERROR) << "aoe_destroy_session_func_ is nullptr";
129 return;
130 }
131 auto aoe_status = aoe_destroy_session_func_(session_id);
132 if (aoe_status != AOE_SUCCESS) {
133 MS_LOG(ERROR) << "Failed to call AoeDestroySession, ret: " << aoe_status;
134 return;
135 }
136 }
AoeSetGeSession(uint64_t session_id,ge::Session * ge_session)137 bool AoeSetGeSession(uint64_t session_id, ge::Session *ge_session) {
138 if (aoe_set_ge_session_func_ == nullptr) {
139 MS_LOG(ERROR) << "aoe_set_ge_session_func_ is nullptr";
140 return false;
141 }
142 auto aoe_status = aoe_set_ge_session_func_(session_id, ge_session);
143 if (aoe_status != AOE_SUCCESS) {
144 MS_LOG(ERROR) << "Failed to call AoeSetGeSession, ret: " << aoe_status;
145 return false;
146 }
147 return true;
148 }
AoeSetTuningGraph(uint64_t session_id,const ge::Graph & ge_graph)149 bool AoeSetTuningGraph(uint64_t session_id, const ge::Graph &ge_graph) {
150 if (aoe_set_tuning_graph_func_ == nullptr) {
151 MS_LOG(ERROR) << "aoe_set_tuning_graph_func_ is nullptr";
152 return false;
153 }
154 auto aoe_status = aoe_set_tuning_graph_func_(session_id, ge_graph);
155 if (aoe_status != AOE_SUCCESS) {
156 MS_LOG(ERROR) << "Failed to call AoeSetTuningGraph, ret: " << aoe_status;
157 return false;
158 }
159 return true;
160 }
AoeSetTuningGraphInput(uint64_t session_id,const std::vector<ge::Tensor> & inputs)161 bool AoeSetTuningGraphInput(uint64_t session_id, const std::vector<ge::Tensor> &inputs) {
162 if (aoe_set_tuning_graph_input_func_ == nullptr) {
163 MS_LOG(ERROR) << "aoe_set_tuning_graph_input_func_ is nullptr";
164 return false;
165 }
166 auto aoe_status = aoe_set_tuning_graph_input_func_(session_id, inputs);
167 if (aoe_status != AOE_SUCCESS) {
168 MS_LOG(ERROR) << "Failed to call AoeSetTuningGraphInput, ret: " << aoe_status;
169 return false;
170 }
171 return true;
172 }
AoeTuningGraph(uint64_t session_id,const std::map<std::string,std::string> & tuning_options)173 bool AoeTuningGraph(uint64_t session_id, const std::map<std::string, std::string> &tuning_options) {
174 if (aoe_tuning_graph_func_ == nullptr) {
175 MS_LOG(ERROR) << "aoe_tuning_graph_func_ is nullptr";
176 return false;
177 }
178 std::map<ge::AscendString, ge::AscendString> options;
179 for (auto &item : tuning_options) {
180 MS_LOG(INFO) << "Aoe tuning option " << item.first << " = " << item.second;
181 options[ge::AscendString(item.first.c_str())] = ge::AscendString(item.second.c_str());
182 }
183 auto aoe_status = aoe_tuning_graph_func_(session_id, options);
184 if (aoe_status != AOE_SUCCESS) {
185 MS_LOG(ERROR) << "Failed to call AoeTuningGraph, ret: " << aoe_status;
186 return false;
187 }
188 return true;
189 }
190
191 private:
192 using AoeInitializeFunc = AoeStatus (*)(const std::map<ge::AscendString, ge::AscendString> &);
193 using AoeFinalizeFunc = AoeStatus (*)();
194 using AoeCreateSessionFunc = AoeStatus (*)(uint64_t &);
195 using AoeDestroySessionFunc = AoeStatus (*)(uint64_t);
196 using AoeSetGeSessionFunc = AoeStatus (*)(uint64_t, ge::Session *);
197 using AoeSetTuningGraphFunc = AoeStatus (*)(uint64_t, const ge::Graph &);
198 using AoeSetTuningGraphInputFunc = AoeStatus (*)(uint64_t, const std::vector<ge::Tensor> &);
199 using AoeTuningGraphFunc = AoeStatus (*)(uint64_t, const std::map<ge::AscendString, ge::AscendString> &);
200
201 AoeInitializeFunc aoe_initialize_func_ = nullptr;
202 AoeFinalizeFunc aoe_finalize_func_ = nullptr;
203 AoeCreateSessionFunc aoe_create_session_func_ = nullptr;
204 AoeDestroySessionFunc aoe_destroy_session_func_ = nullptr;
205 AoeSetGeSessionFunc aoe_set_ge_session_func_ = nullptr;
206 AoeSetTuningGraphFunc aoe_set_tuning_graph_func_ = nullptr;
207 AoeSetTuningGraphInputFunc aoe_set_tuning_graph_input_func_ = nullptr;
208 AoeTuningGraphFunc aoe_tuning_graph_func_ = nullptr;
209
210 void *handle_ = nullptr;
211 };
212
ExecuteAoe(const std::shared_ptr<ge::Session> & session,const transform::DfGraphPtr & graph,const std::vector<ge::Tensor> & inputs,const std::vector<std::string> & job_types,const std::map<std::string,std::string> & global_options,const std::map<std::string,std::string> & tuning_options)213 Status AoeApiTuning::ExecuteAoe(const std::shared_ptr<ge::Session> &session, const transform::DfGraphPtr &graph,
214 const std::vector<ge::Tensor> &inputs, const std::vector<std::string> &job_types,
215 const std::map<std::string, std::string> &global_options,
216 const std::map<std::string, std::string> &tuning_options) {
217 MS_LOG(INFO) << "Start to aoe.";
218 try {
219 auto &aoe_instance = AoePlugin::Instance();
220 if (!aoe_instance.LoadAoePlugin()) {
221 return kLiteError;
222 }
223 for (auto &job_type : job_types) {
224 std::cout << "Start to " << kTuneModeMap.at(job_type) << std::endl;
225 std::map<std::string, std::string> global_options_new = global_options;
226 global_options_new["job_type"] = job_type;
227 if (!aoe_instance.AoeInitialize(global_options_new)) {
228 return kLiteError;
229 }
230 uint64_t session_id = 0;
231 if (!aoe_instance.AoeCreateSession(&session_id)) {
232 aoe_instance.AoeFinalize();
233 return kLiteError;
234 }
235 if (session && !aoe_instance.AoeSetGeSession(session_id, session.get())) {
236 aoe_instance.AoeDestroySession(session_id);
237 aoe_instance.AoeFinalize();
238 return kLiteError;
239 }
240 if (!aoe_instance.AoeSetTuningGraph(session_id, *graph)) {
241 aoe_instance.AoeDestroySession(session_id);
242 aoe_instance.AoeFinalize();
243 return kLiteError;
244 }
245 if (!inputs.empty() && !aoe_instance.AoeSetTuningGraphInput(session_id, inputs)) {
246 aoe_instance.AoeDestroySession(session_id);
247 aoe_instance.AoeFinalize();
248 return kLiteError;
249 }
250 if (!aoe_instance.AoeTuningGraph(session_id, tuning_options)) {
251 aoe_instance.AoeDestroySession(session_id);
252 aoe_instance.AoeFinalize();
253 return kLiteError;
254 }
255 aoe_instance.AoeDestroySession(session_id);
256 aoe_instance.AoeFinalize();
257 std::cout << "End " << kTuneModeMap.at(job_type) << std::endl;
258 }
259 return kSuccess;
260 } catch (const std::exception &e) {
261 MS_LOG(ERROR) << "Execute aoe failed: " << e.what();
262 } catch (...) {
263 MS_LOG(ERROR) << "Execute aoe failed.";
264 }
265 return kMCFailed;
266 }
267
GetAscendDeviceInfo(const std::shared_ptr<Context> & context)268 static std::shared_ptr<AscendDeviceInfo> GetAscendDeviceInfo(const std::shared_ptr<Context> &context) {
269 if (context == nullptr) {
270 return nullptr;
271 }
272 auto &device_infos = context->MutableDeviceInfo();
273 if (device_infos.size() != 1 || device_infos[0] == nullptr) {
274 return nullptr;
275 }
276 auto ascend_info = device_infos[0]->Cast<AscendDeviceInfo>();
277 if (ascend_info == nullptr) {
278 return nullptr;
279 }
280 return ascend_info;
281 }
282
// Copies one entry from config_options into *aoe_options.
// The value is looked up under get_key (or under key when get_key is empty) and stored under key.
// *aoe_options is left untouched when config_options has no such entry.
static void SetOption(std::map<std::string, std::string> *aoe_options, const std::string &key,
                      const std::map<std::string, std::string> &config_options, std::string get_key = "") {
  const std::string &lookup_key = get_key.empty() ? key : get_key;
  const auto found = config_options.find(lookup_key);
  if (found != config_options.end()) {
    (*aoe_options)[key] = found->second;
  }
}
294
// Stores the string produced by func under key in *aoe_options, unless that string is empty.
static void SetOption(std::map<std::string, std::string> *aoe_options, const std::string &key,
                      const std::function<std::string()> &func) {
  const std::string value = func();
  if (value.empty()) {
    return;
  }
  (*aoe_options)[key] = value;
}
302
GetAoeGlobalOptions(const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)303 std::map<std::string, std::string> AoeApiTuning::GetAoeGlobalOptions(const std::shared_ptr<Context> &context,
304 const ConfigInfos &config_infos) {
305 // framework, device, precision_mode
306 std::map<std::string, std::string> aoe_options;
307 aoe_options["framework"] = "1";
308 // get options from [acl_option_cfg_param]
309 auto section_it = config_infos.find(lite::kAclOptionParam);
310 if (section_it != config_infos.end()) {
311 auto &options = section_it->second;
312 auto option_it = options.find("precision_mode");
313 if (option_it != options.end()) {
314 aoe_options["precision_mode"] = TransforPrecisionToAcl(option_it->second);
315 }
316 }
317 // get options from AscendDeviceInfo: may parse from [ascend_context] & [acl_option_cfg_param]
318 auto ascend_info = GetAscendDeviceInfo(context);
319 if (ascend_info == nullptr) {
320 MS_LOG(WARNING) << "Failed to get ascend device info from context";
321 return {};
322 }
323 aoe_options["device"] = std::to_string(ascend_info->GetDeviceID());
324 auto precision_mode = ascend_info->GetPrecisionMode();
325 if (!precision_mode.empty()) {
326 aoe_options["precision_mode"] = TransforPrecisionToAcl(precision_mode);
327 }
328 // get options from [ge_session_options]
329 section_it = config_infos.find(lite::kGeSessionOptionsSection);
330 if (section_it != config_infos.end()) {
331 auto &options = section_it->second;
332 SetOption(&aoe_options, "precision_mode", options, "ge.exec.precision_mode");
333 }
334 // get options from [ge_graph_options]
335 section_it = config_infos.find(lite::kGeGraphOptionsSection);
336 if (section_it != config_infos.end()) {
337 auto &options = section_it->second;
338 SetOption(&aoe_options, "precision_mode", options, "ge.exec.precision_mode");
339 }
340 // get options from [aoe_global_options]
341 section_it = config_infos.find(lite::kAoeGlobalOptionsSection);
342 if (section_it != config_infos.end()) {
343 for (auto &option_item : section_it->second) {
344 aoe_options[option_item.first] = option_item.second;
345 MS_LOG(INFO) << "Update global option " << option_item.first << " to " << option_item.second;
346 }
347 }
348 return aoe_options;
349 }
350
GetAoeTuningOptions(const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)351 std::map<std::string, std::string> AoeApiTuning::GetAoeTuningOptions(const std::shared_ptr<Context> &context,
352 const ConfigInfos &config_infos) {
353 // input_shape,dynamic_batch_size, dynamic_image_size, dynamic_dims & input_format
354 std::map<std::string, std::string> aoe_options;
355 // get options from [acl_option_cfg_param]
356 auto section_it = config_infos.find(lite::kAclOptionParam);
357 if (section_it != config_infos.end()) {
358 auto &options = section_it->second;
359 SetOption(&aoe_options, "input_shape", options);
360 SetOption(&aoe_options, "input_format", options);
361 SetOption(&aoe_options, "dynamic_batch_size", options);
362 SetOption(&aoe_options, "dynamic_image_size", options);
363 SetOption(&aoe_options, "dynamic_dims", options);
364 }
365 // get options for AscendDeviceInfo: may parse from [ascend_context] & [acl_option_cfg_param]
366 auto ascend_info = GetAscendDeviceInfo(context);
367 if (ascend_info == nullptr) {
368 MS_LOG(WARNING) << "Failed to get ascend device info from context";
369 return {};
370 }
371 SetOption(&aoe_options, "input_format", [ascend_info]() { return ascend_info->GetInputFormat(); });
372 SetOption(&aoe_options, "input_shape", [ascend_info]() { return ascend_info->GetInputShape(); });
373 SetOption(&aoe_options, "dynamic_batch_size", [ascend_info]() { return ascend_info->GetDynamicBatchSize(); });
374 SetOption(&aoe_options, "dynamic_image_size", [ascend_info]() { return ascend_info->GetDynamicImageSize(); });
375
376 // get options from [ge_graph_options]
377 section_it = config_infos.find(lite::kGeGraphOptionsSection);
378 if (section_it != config_infos.end()) {
379 auto &options = section_it->second;
380 SetOption(&aoe_options, "input_shape", options, "ge.inputShape");
381 SetOption(&aoe_options, "dynamic_dims", options, "ge.dynamicDims");
382 }
383 // get options from [aoe_tuning_options]
384 section_it = config_infos.find(lite::kAoeTuningOptionsSection);
385 if (section_it != config_infos.end()) {
386 for (auto &option_item : section_it->second) {
387 aoe_options[option_item.first] = option_item.second;
388 MS_LOG(INFO) << "Update tuning option " << option_item.first << " to " << option_item.second;
389 }
390 }
391 if ((aoe_options.find("dynamic_batch_size") != aoe_options.end() ||
392 aoe_options.find("dynamic_dims") != aoe_options.end()) &&
393 aoe_options.find("input_format") == aoe_options.end()) {
394 aoe_options["input_format"] = "ND";
395 }
396 return aoe_options;
397 }
398
GetAoeJobType(const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)399 std::vector<std::string> AoeApiTuning::GetAoeJobType(const std::shared_ptr<Context> &context,
400 const ConfigInfos &config_infos) {
401 std::vector<std::string> job_types;
402 // get options from [acl_option_cfg_param]
403 auto section_it = config_infos.find(lite::kAclOptionParam);
404 if (section_it != config_infos.end()) {
405 auto &options = section_it->second;
406 auto option_it = options.find("aoe_mode");
407 if (option_it != options.end()) {
408 auto &option = option_it->second;
409 if (option.find(kOperatorTurning) != std::string::npos) {
410 job_types.push_back(kOperatorTurningIndex);
411 }
412 if (option.find(kSubgraphTurning) != std::string::npos) {
413 job_types.push_back(kSubgraphTurningIndex);
414 }
415 }
416 }
417 // get options from [ascend_context]
418 section_it = config_infos.find(lite::kAscendContextSection);
419 if (section_it != config_infos.end()) {
420 auto &options = section_it->second;
421 auto option_it = options.find("aoe_mode");
422 if (option_it != options.end()) {
423 job_types.clear();
424 auto &option = option_it->second;
425 if (option.find(kOperatorTurning) != std::string::npos) {
426 job_types.push_back(kOperatorTurningIndex);
427 }
428 if (option.find(kSubgraphTurning) != std::string::npos) {
429 job_types.push_back(kSubgraphTurningIndex);
430 }
431 }
432 }
433 if (job_types.size() > 1) {
434 MS_LOG(ERROR) << "Config aoe_mode should only be " << kOperatorTurning << " or " << kSubgraphTurning
435 << " when provider is ge";
436 return {};
437 }
438 // get options from [aoe_global_options]
439 section_it = config_infos.find(lite::kAoeGlobalOptionsSection);
440 if (section_it != config_infos.end()) {
441 auto &options = section_it->second;
442 auto option_it = options.find("job_type");
443 if (option_it != options.end()) {
444 auto option = option_it->second;
445 if (option != kSubgraphTurningIndex && option != kOperatorTurningIndex) {
446 MS_LOG(ERROR) << "Config job_type should only be " << kOperatorTurningIndex << " or " << kSubgraphTurningIndex
447 << " when provider is ge";
448 return {};
449 }
450 job_types.clear();
451 job_types.push_back(option_it->second);
452 }
453 }
454 if (job_types.empty()) {
455 MS_LOG(ERROR) << "Option aoe_mode or job_type is not set, option aoe_mode should be in section "
456 << lite::kAclOptionParam << " or " << lite::kAscendContextSection
457 << ", job_type should be in section " << lite::kAoeGlobalOptionsSection;
458 }
459 return job_types;
460 }
461
AoeTurningGraph(const std::shared_ptr<ge::Session> & session,const transform::DfGraphPtr & graph,const std::vector<ge::Tensor> & inputs,const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)462 Status AoeApiTuning::AoeTurningGraph(const std::shared_ptr<ge::Session> &session, const transform::DfGraphPtr &graph,
463 const std::vector<ge::Tensor> &inputs, const std::shared_ptr<Context> &context,
464 const ConfigInfos &config_infos) {
465 auto global_options = GetAoeGlobalOptions(context, config_infos);
466 auto tuning_options = GetAoeTuningOptions(context, config_infos);
467 auto job_types = GetAoeJobType(context, config_infos);
468 if (job_types.empty()) {
469 return kLiteError;
470 }
471 if (ExecuteAoe(session, graph, inputs, job_types, global_options, tuning_options) != kSuccess) {
472 MS_LOG(ERROR) << "Execute aoe failed";
473 return kLiteError;
474 }
475 return kSuccess;
476 }
477 } // namespace mindspore
478