• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "extendrt/delegate/ascend_ge/aoe_api_tune_process.h"
18 #include <cstdio>
19 #include <iostream>
20 #include <tuple>
21 #include <vector>
22 #include <string>
23 #include <map>
24 #include "mindspore/lite/src/common/common.h"
25 #include "mindspore/lite/src/extendrt/cxx_api/dlutils.h"
26 #include "mindspore/ccsrc/utils/dlopen_macro.h"
27 #include "mindspore/ccsrc/cxx_api/acl_utils.h"
28 
29 namespace mindspore {
namespace {
// Names accepted in the `aoe_mode` config option.
constexpr auto kSubgraphTurning = "subgraph tuning";
constexpr auto kOperatorTurning = "operator tuning";
// Values used for the AOE `job_type` option.
constexpr auto kSubgraphTurningIndex = "1";
constexpr auto kOperatorTurningIndex = "2";

// Maps a `job_type` value to the name printed while that job is running.
const std::map<std::string, std::string> kTuneModeMap = {
  {kSubgraphTurningIndex, kSubgraphTurning},
  {kOperatorTurningIndex, kOperatorTurning},
};
}  // namespace
38 
39 using AoeStatus = int32_t;
40 constexpr AoeStatus AOE_SUCCESS = 0;
41 
42 class AoePlugin {
43  public:
Instance()44   static AoePlugin &Instance() {
45     static AoePlugin instance;
46     return instance;
47   }
48   AoePlugin() = default;
~AoePlugin()49   ~AoePlugin() { DLSoClose(handle_); }
LoadAoePlugin()50   bool LoadAoePlugin() {
51     if (handle_ != nullptr) {
52       return true;
53     }
54     std::string aoe_so_name = "libaoe_tuning.so";
55     std::string aoe_init_func_name = "AoeInitialize";
56     std::string aoe_finalize_func_name = "AoeFinalize";
57     std::string aoe_create_session_func_name = "AoeCreateSession";
58     std::string aoe_destroy_session_func_name = "AoeDestroySession";
59     std::string aoe_set_ge_session_func_name = "AoeSetGeSession";
60     std::string aoe_set_tuning_graph_func_name = "AoeSetTuningGraph";
61     std::string aoe_set_tuning_graph_input_func_name = "AoeSetTuningGraphInput";
62     std::string aoe_tuning_graph_func_name = "AoeTuningGraph";
63 
64     auto status = DLSoOpen(aoe_so_name, "", &handle_, nullptr);
65     if (status != kSuccess) {
66       MS_LOG(ERROR) << "Dlopen " << aoe_so_name << " failed, result = " << status.ToString();
67       return false;
68     }
69     try {
70       aoe_initialize_func_ = DlsymWithCast<AoeInitializeFunc>(handle_, aoe_init_func_name.c_str());
71       aoe_finalize_func_ = DlsymWithCast<AoeFinalizeFunc>(handle_, aoe_finalize_func_name.c_str());
72       aoe_create_session_func_ = DlsymWithCast<AoeCreateSessionFunc>(handle_, aoe_create_session_func_name.c_str());
73       aoe_destroy_session_func_ = DlsymWithCast<AoeDestroySessionFunc>(handle_, aoe_destroy_session_func_name.c_str());
74       aoe_set_ge_session_func_ = DlsymWithCast<AoeSetGeSessionFunc>(handle_, aoe_set_ge_session_func_name.c_str());
75       aoe_set_tuning_graph_func_ =
76         DlsymWithCast<AoeSetTuningGraphFunc>(handle_, aoe_set_tuning_graph_func_name.c_str());
77       aoe_set_tuning_graph_input_func_ =
78         DlsymWithCast<AoeSetTuningGraphInputFunc>(handle_, aoe_set_tuning_graph_input_func_name.c_str());
79       aoe_tuning_graph_func_ = DlsymWithCast<AoeTuningGraphFunc>(handle_, aoe_tuning_graph_func_name.c_str());
80     } catch (const std::runtime_error &error) {
81       MS_LOG(ERROR) << "Failed to load symbol from " << aoe_so_name;
82       return false;
83     }
84     return true;
85   }
AoeInitialize(const std::map<std::string,std::string> & global_options)86   bool AoeInitialize(const std::map<std::string, std::string> &global_options) {
87     if (aoe_initialize_func_ == nullptr) {
88       MS_LOG(ERROR) << "aoe_initialize_func_ is nullptr";
89       return false;
90     }
91     std::map<ge::AscendString, ge::AscendString> options;
92     for (auto &item : global_options) {
93       MS_LOG(INFO) << "Aoe global option " << item.first << " = " << item.second;
94       options[ge::AscendString(item.first.c_str())] = ge::AscendString(item.second.c_str());
95     }
96     auto aoe_status = aoe_initialize_func_(options);
97     if (aoe_status != AOE_SUCCESS) {
98       MS_LOG(ERROR) << "Failed to call AoeInitialize, ret: " << aoe_status;
99       return false;
100     }
101     return true;
102   }
AoeFinalize()103   void AoeFinalize() {
104     if (aoe_finalize_func_ == nullptr) {
105       MS_LOG(ERROR) << "aoe_finalize_func_ is nullptr";
106       return;
107     }
108     aoe_finalize_func_();
109   }
AoeCreateSession(uint64_t * session_id)110   bool AoeCreateSession(uint64_t *session_id) {
111     if (session_id == nullptr) {
112       MS_LOG(ERROR) << "Input parameter session_id cannot be nullptr";
113       return false;
114     }
115     if (aoe_create_session_func_ == nullptr) {
116       MS_LOG(ERROR) << "aoe_create_session_func_ is nullptr";
117       return false;
118     }
119     auto aoe_status = aoe_create_session_func_(*session_id);
120     if (aoe_status != AOE_SUCCESS) {
121       MS_LOG(ERROR) << "Failed to call AoeCreateSession, ret: " << aoe_status;
122       return false;
123     }
124     return true;
125   }
AoeDestroySession(uint64_t session_id)126   void AoeDestroySession(uint64_t session_id) {
127     if (aoe_destroy_session_func_ == nullptr) {
128       MS_LOG(ERROR) << "aoe_destroy_session_func_ is nullptr";
129       return;
130     }
131     auto aoe_status = aoe_destroy_session_func_(session_id);
132     if (aoe_status != AOE_SUCCESS) {
133       MS_LOG(ERROR) << "Failed to call AoeDestroySession, ret: " << aoe_status;
134       return;
135     }
136   }
AoeSetGeSession(uint64_t session_id,ge::Session * ge_session)137   bool AoeSetGeSession(uint64_t session_id, ge::Session *ge_session) {
138     if (aoe_set_ge_session_func_ == nullptr) {
139       MS_LOG(ERROR) << "aoe_set_ge_session_func_ is nullptr";
140       return false;
141     }
142     auto aoe_status = aoe_set_ge_session_func_(session_id, ge_session);
143     if (aoe_status != AOE_SUCCESS) {
144       MS_LOG(ERROR) << "Failed to call AoeSetGeSession, ret: " << aoe_status;
145       return false;
146     }
147     return true;
148   }
AoeSetTuningGraph(uint64_t session_id,const ge::Graph & ge_graph)149   bool AoeSetTuningGraph(uint64_t session_id, const ge::Graph &ge_graph) {
150     if (aoe_set_tuning_graph_func_ == nullptr) {
151       MS_LOG(ERROR) << "aoe_set_tuning_graph_func_ is nullptr";
152       return false;
153     }
154     auto aoe_status = aoe_set_tuning_graph_func_(session_id, ge_graph);
155     if (aoe_status != AOE_SUCCESS) {
156       MS_LOG(ERROR) << "Failed to call AoeSetTuningGraph, ret: " << aoe_status;
157       return false;
158     }
159     return true;
160   }
AoeSetTuningGraphInput(uint64_t session_id,const std::vector<ge::Tensor> & inputs)161   bool AoeSetTuningGraphInput(uint64_t session_id, const std::vector<ge::Tensor> &inputs) {
162     if (aoe_set_tuning_graph_input_func_ == nullptr) {
163       MS_LOG(ERROR) << "aoe_set_tuning_graph_input_func_ is nullptr";
164       return false;
165     }
166     auto aoe_status = aoe_set_tuning_graph_input_func_(session_id, inputs);
167     if (aoe_status != AOE_SUCCESS) {
168       MS_LOG(ERROR) << "Failed to call AoeSetTuningGraphInput, ret: " << aoe_status;
169       return false;
170     }
171     return true;
172   }
AoeTuningGraph(uint64_t session_id,const std::map<std::string,std::string> & tuning_options)173   bool AoeTuningGraph(uint64_t session_id, const std::map<std::string, std::string> &tuning_options) {
174     if (aoe_tuning_graph_func_ == nullptr) {
175       MS_LOG(ERROR) << "aoe_tuning_graph_func_ is nullptr";
176       return false;
177     }
178     std::map<ge::AscendString, ge::AscendString> options;
179     for (auto &item : tuning_options) {
180       MS_LOG(INFO) << "Aoe tuning option " << item.first << " = " << item.second;
181       options[ge::AscendString(item.first.c_str())] = ge::AscendString(item.second.c_str());
182     }
183     auto aoe_status = aoe_tuning_graph_func_(session_id, options);
184     if (aoe_status != AOE_SUCCESS) {
185       MS_LOG(ERROR) << "Failed to call AoeTuningGraph, ret: " << aoe_status;
186       return false;
187     }
188     return true;
189   }
190 
191  private:
192   using AoeInitializeFunc = AoeStatus (*)(const std::map<ge::AscendString, ge::AscendString> &);
193   using AoeFinalizeFunc = AoeStatus (*)();
194   using AoeCreateSessionFunc = AoeStatus (*)(uint64_t &);
195   using AoeDestroySessionFunc = AoeStatus (*)(uint64_t);
196   using AoeSetGeSessionFunc = AoeStatus (*)(uint64_t, ge::Session *);
197   using AoeSetTuningGraphFunc = AoeStatus (*)(uint64_t, const ge::Graph &);
198   using AoeSetTuningGraphInputFunc = AoeStatus (*)(uint64_t, const std::vector<ge::Tensor> &);
199   using AoeTuningGraphFunc = AoeStatus (*)(uint64_t, const std::map<ge::AscendString, ge::AscendString> &);
200 
201   AoeInitializeFunc aoe_initialize_func_ = nullptr;
202   AoeFinalizeFunc aoe_finalize_func_ = nullptr;
203   AoeCreateSessionFunc aoe_create_session_func_ = nullptr;
204   AoeDestroySessionFunc aoe_destroy_session_func_ = nullptr;
205   AoeSetGeSessionFunc aoe_set_ge_session_func_ = nullptr;
206   AoeSetTuningGraphFunc aoe_set_tuning_graph_func_ = nullptr;
207   AoeSetTuningGraphInputFunc aoe_set_tuning_graph_input_func_ = nullptr;
208   AoeTuningGraphFunc aoe_tuning_graph_func_ = nullptr;
209 
210   void *handle_ = nullptr;
211 };
212 
ExecuteAoe(const std::shared_ptr<ge::Session> & session,const transform::DfGraphPtr & graph,const std::vector<ge::Tensor> & inputs,const std::vector<std::string> & job_types,const std::map<std::string,std::string> & global_options,const std::map<std::string,std::string> & tuning_options)213 Status AoeApiTuning::ExecuteAoe(const std::shared_ptr<ge::Session> &session, const transform::DfGraphPtr &graph,
214                                 const std::vector<ge::Tensor> &inputs, const std::vector<std::string> &job_types,
215                                 const std::map<std::string, std::string> &global_options,
216                                 const std::map<std::string, std::string> &tuning_options) {
217   MS_LOG(INFO) << "Start to aoe.";
218   try {
219     auto &aoe_instance = AoePlugin::Instance();
220     if (!aoe_instance.LoadAoePlugin()) {
221       return kLiteError;
222     }
223     for (auto &job_type : job_types) {
224       std::cout << "Start to " << kTuneModeMap.at(job_type) << std::endl;
225       std::map<std::string, std::string> global_options_new = global_options;
226       global_options_new["job_type"] = job_type;
227       if (!aoe_instance.AoeInitialize(global_options_new)) {
228         return kLiteError;
229       }
230       uint64_t session_id = 0;
231       if (!aoe_instance.AoeCreateSession(&session_id)) {
232         aoe_instance.AoeFinalize();
233         return kLiteError;
234       }
235       if (session && !aoe_instance.AoeSetGeSession(session_id, session.get())) {
236         aoe_instance.AoeDestroySession(session_id);
237         aoe_instance.AoeFinalize();
238         return kLiteError;
239       }
240       if (!aoe_instance.AoeSetTuningGraph(session_id, *graph)) {
241         aoe_instance.AoeDestroySession(session_id);
242         aoe_instance.AoeFinalize();
243         return kLiteError;
244       }
245       if (!inputs.empty() && !aoe_instance.AoeSetTuningGraphInput(session_id, inputs)) {
246         aoe_instance.AoeDestroySession(session_id);
247         aoe_instance.AoeFinalize();
248         return kLiteError;
249       }
250       if (!aoe_instance.AoeTuningGraph(session_id, tuning_options)) {
251         aoe_instance.AoeDestroySession(session_id);
252         aoe_instance.AoeFinalize();
253         return kLiteError;
254       }
255       aoe_instance.AoeDestroySession(session_id);
256       aoe_instance.AoeFinalize();
257       std::cout << "End " << kTuneModeMap.at(job_type) << std::endl;
258     }
259     return kSuccess;
260   } catch (const std::exception &e) {
261     MS_LOG(ERROR) << "Execute aoe failed: " << e.what();
262   } catch (...) {
263     MS_LOG(ERROR) << "Execute aoe failed.";
264   }
265   return kMCFailed;
266 }
267 
GetAscendDeviceInfo(const std::shared_ptr<Context> & context)268 static std::shared_ptr<AscendDeviceInfo> GetAscendDeviceInfo(const std::shared_ptr<Context> &context) {
269   if (context == nullptr) {
270     return nullptr;
271   }
272   auto &device_infos = context->MutableDeviceInfo();
273   if (device_infos.size() != 1 || device_infos[0] == nullptr) {
274     return nullptr;
275   }
276   auto ascend_info = device_infos[0]->Cast<AscendDeviceInfo>();
277   if (ascend_info == nullptr) {
278     return nullptr;
279   }
280   return ascend_info;
281 }
282 
// Copies one entry from `config_options` into `*aoe_options`.
// The value is looked up under `get_key` (or under `key` when `get_key` is empty) and stored under
// `key`. Nothing is written when the lookup key is absent.
static void SetOption(std::map<std::string, std::string> *aoe_options, const std::string &key,
                      const std::map<std::string, std::string> &config_options, std::string get_key = "") {
  const std::string &lookup_key = get_key.empty() ? key : get_key;
  const auto found = config_options.find(lookup_key);
  if (found != config_options.end()) {
    (*aoe_options)[key] = found->second;
  }
}
294 
// Stores the string produced by `func` under `key` in `*aoe_options`.
// `func` is always invoked exactly once; an empty result leaves the map untouched.
static void SetOption(std::map<std::string, std::string> *aoe_options, const std::string &key,
                      const std::function<std::string()> &func) {
  const std::string value = func();
  if (value.empty()) {
    return;
  }
  (*aoe_options)[key] = value;
}
302 
GetAoeGlobalOptions(const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)303 std::map<std::string, std::string> AoeApiTuning::GetAoeGlobalOptions(const std::shared_ptr<Context> &context,
304                                                                      const ConfigInfos &config_infos) {
305   // framework, device, precision_mode
306   std::map<std::string, std::string> aoe_options;
307   aoe_options["framework"] = "1";
308   // get options from [acl_option_cfg_param]
309   auto section_it = config_infos.find(lite::kAclOptionParam);
310   if (section_it != config_infos.end()) {
311     auto &options = section_it->second;
312     auto option_it = options.find("precision_mode");
313     if (option_it != options.end()) {
314       aoe_options["precision_mode"] = TransforPrecisionToAcl(option_it->second);
315     }
316   }
317   // get options from AscendDeviceInfo: may parse from [ascend_context] & [acl_option_cfg_param]
318   auto ascend_info = GetAscendDeviceInfo(context);
319   if (ascend_info == nullptr) {
320     MS_LOG(WARNING) << "Failed to get ascend device info from context";
321     return {};
322   }
323   aoe_options["device"] = std::to_string(ascend_info->GetDeviceID());
324   auto precision_mode = ascend_info->GetPrecisionMode();
325   if (!precision_mode.empty()) {
326     aoe_options["precision_mode"] = TransforPrecisionToAcl(precision_mode);
327   }
328   // get options from [ge_session_options]
329   section_it = config_infos.find(lite::kGeSessionOptionsSection);
330   if (section_it != config_infos.end()) {
331     auto &options = section_it->second;
332     SetOption(&aoe_options, "precision_mode", options, "ge.exec.precision_mode");
333   }
334   // get options from [ge_graph_options]
335   section_it = config_infos.find(lite::kGeGraphOptionsSection);
336   if (section_it != config_infos.end()) {
337     auto &options = section_it->second;
338     SetOption(&aoe_options, "precision_mode", options, "ge.exec.precision_mode");
339   }
340   // get options from [aoe_global_options]
341   section_it = config_infos.find(lite::kAoeGlobalOptionsSection);
342   if (section_it != config_infos.end()) {
343     for (auto &option_item : section_it->second) {
344       aoe_options[option_item.first] = option_item.second;
345       MS_LOG(INFO) << "Update global option " << option_item.first << " to " << option_item.second;
346     }
347   }
348   return aoe_options;
349 }
350 
GetAoeTuningOptions(const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)351 std::map<std::string, std::string> AoeApiTuning::GetAoeTuningOptions(const std::shared_ptr<Context> &context,
352                                                                      const ConfigInfos &config_infos) {
353   // input_shape,dynamic_batch_size, dynamic_image_size, dynamic_dims & input_format
354   std::map<std::string, std::string> aoe_options;
355   // get options from [acl_option_cfg_param]
356   auto section_it = config_infos.find(lite::kAclOptionParam);
357   if (section_it != config_infos.end()) {
358     auto &options = section_it->second;
359     SetOption(&aoe_options, "input_shape", options);
360     SetOption(&aoe_options, "input_format", options);
361     SetOption(&aoe_options, "dynamic_batch_size", options);
362     SetOption(&aoe_options, "dynamic_image_size", options);
363     SetOption(&aoe_options, "dynamic_dims", options);
364   }
365   // get options for AscendDeviceInfo: may parse from [ascend_context] & [acl_option_cfg_param]
366   auto ascend_info = GetAscendDeviceInfo(context);
367   if (ascend_info == nullptr) {
368     MS_LOG(WARNING) << "Failed to get ascend device info from context";
369     return {};
370   }
371   SetOption(&aoe_options, "input_format", [ascend_info]() { return ascend_info->GetInputFormat(); });
372   SetOption(&aoe_options, "input_shape", [ascend_info]() { return ascend_info->GetInputShape(); });
373   SetOption(&aoe_options, "dynamic_batch_size", [ascend_info]() { return ascend_info->GetDynamicBatchSize(); });
374   SetOption(&aoe_options, "dynamic_image_size", [ascend_info]() { return ascend_info->GetDynamicImageSize(); });
375 
376   // get options from [ge_graph_options]
377   section_it = config_infos.find(lite::kGeGraphOptionsSection);
378   if (section_it != config_infos.end()) {
379     auto &options = section_it->second;
380     SetOption(&aoe_options, "input_shape", options, "ge.inputShape");
381     SetOption(&aoe_options, "dynamic_dims", options, "ge.dynamicDims");
382   }
383   // get options from [aoe_tuning_options]
384   section_it = config_infos.find(lite::kAoeTuningOptionsSection);
385   if (section_it != config_infos.end()) {
386     for (auto &option_item : section_it->second) {
387       aoe_options[option_item.first] = option_item.second;
388       MS_LOG(INFO) << "Update tuning option " << option_item.first << " to " << option_item.second;
389     }
390   }
391   if ((aoe_options.find("dynamic_batch_size") != aoe_options.end() ||
392        aoe_options.find("dynamic_dims") != aoe_options.end()) &&
393       aoe_options.find("input_format") == aoe_options.end()) {
394     aoe_options["input_format"] = "ND";
395   }
396   return aoe_options;
397 }
398 
GetAoeJobType(const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)399 std::vector<std::string> AoeApiTuning::GetAoeJobType(const std::shared_ptr<Context> &context,
400                                                      const ConfigInfos &config_infos) {
401   std::vector<std::string> job_types;
402   // get options from [acl_option_cfg_param]
403   auto section_it = config_infos.find(lite::kAclOptionParam);
404   if (section_it != config_infos.end()) {
405     auto &options = section_it->second;
406     auto option_it = options.find("aoe_mode");
407     if (option_it != options.end()) {
408       auto &option = option_it->second;
409       if (option.find(kOperatorTurning) != std::string::npos) {
410         job_types.push_back(kOperatorTurningIndex);
411       }
412       if (option.find(kSubgraphTurning) != std::string::npos) {
413         job_types.push_back(kSubgraphTurningIndex);
414       }
415     }
416   }
417   // get options from [ascend_context]
418   section_it = config_infos.find(lite::kAscendContextSection);
419   if (section_it != config_infos.end()) {
420     auto &options = section_it->second;
421     auto option_it = options.find("aoe_mode");
422     if (option_it != options.end()) {
423       job_types.clear();
424       auto &option = option_it->second;
425       if (option.find(kOperatorTurning) != std::string::npos) {
426         job_types.push_back(kOperatorTurningIndex);
427       }
428       if (option.find(kSubgraphTurning) != std::string::npos) {
429         job_types.push_back(kSubgraphTurningIndex);
430       }
431     }
432   }
433   if (job_types.size() > 1) {
434     MS_LOG(ERROR) << "Config aoe_mode should only be " << kOperatorTurning << " or " << kSubgraphTurning
435                   << " when provider is ge";
436     return {};
437   }
438   // get options from [aoe_global_options]
439   section_it = config_infos.find(lite::kAoeGlobalOptionsSection);
440   if (section_it != config_infos.end()) {
441     auto &options = section_it->second;
442     auto option_it = options.find("job_type");
443     if (option_it != options.end()) {
444       auto option = option_it->second;
445       if (option != kSubgraphTurningIndex && option != kOperatorTurningIndex) {
446         MS_LOG(ERROR) << "Config job_type should only be " << kOperatorTurningIndex << " or " << kSubgraphTurningIndex
447                       << " when provider is ge";
448         return {};
449       }
450       job_types.clear();
451       job_types.push_back(option_it->second);
452     }
453   }
454   if (job_types.empty()) {
455     MS_LOG(ERROR) << "Option aoe_mode or job_type is not set, option aoe_mode should be in section "
456                   << lite::kAclOptionParam << " or " << lite::kAscendContextSection
457                   << ", job_type should be in section " << lite::kAoeGlobalOptionsSection;
458   }
459   return job_types;
460 }
461 
AoeTurningGraph(const std::shared_ptr<ge::Session> & session,const transform::DfGraphPtr & graph,const std::vector<ge::Tensor> & inputs,const std::shared_ptr<Context> & context,const ConfigInfos & config_infos)462 Status AoeApiTuning::AoeTurningGraph(const std::shared_ptr<ge::Session> &session, const transform::DfGraphPtr &graph,
463                                      const std::vector<ge::Tensor> &inputs, const std::shared_ptr<Context> &context,
464                                      const ConfigInfos &config_infos) {
465   auto global_options = GetAoeGlobalOptions(context, config_infos);
466   auto tuning_options = GetAoeTuningOptions(context, config_infos);
467   auto job_types = GetAoeJobType(context, config_infos);
468   if (job_types.empty()) {
469     return kLiteError;
470   }
471   if (ExecuteAoe(session, graph, inputs, job_types, global_options, tuning_options) != kSuccess) {
472     MS_LOG(ERROR) << "Execute aoe failed";
473     return kLiteError;
474   }
475   return kSuccess;
476 }
477 }  // namespace mindspore
478