• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <dlfcn.h>
17 #include <cxxabi.h>
18 #include <set>
19 #include <string>
20 #include "include/common/debug/common.h"
21 #include "transform/graph_ir/aoe_util.h"
22 #include "utils/file_utils.h"
23 #include "utils/ms_context.h"
24 #include "transform/symbol/acl_base_symbol.h"
25 #include "transform/symbol/symbol_utils.h"
26 
27 namespace mindspore {
28 namespace transform {
29 namespace AoeOptions {
30 const ::ge::AscendString JOB_TYPE = ::ge::AscendString("job_type");
31 const ::ge::AscendString FRAMEWORK = ::ge::AscendString("framework");
32 const ::ge::AscendString LOG_LEVEL = ::ge::AscendString("log");
33 const ::ge::AscendString PRECISION_MODE = ::ge::AscendString("precision_mode");
34 }  // namespace AoeOptions
35 
IsAscendServer()36 bool IsAscendServer() {
37   auto ms_context = MsContext::GetInstance();
38   MS_EXCEPTION_IF_NULL(ms_context);
39   return ms_context->ascend_soc_version().find("ascend910") != std::string::npos;
40 }
41 
AoeUtil()42 AoeUtil::AoeUtil() : initialize_(false) {}
43 
~AoeUtil()44 AoeUtil::~AoeUtil() { MS_LOG(INFO) << "release aoeutil success."; }
45 
Initialize()46 void AoeUtil::Initialize() {
47   if (initialize_) {
48     MS_LOG(INFO) << "Aoe already initialized.";
49     return;
50   }
51   if (IsAscendServer()) {
52     std::string ascend_path = GetAscendPath();
53     auto ld_library_path = common::GetEnv("LD_LIBRARY_PATH");
54     ld_library_path = ascend_path + "lib64:" + ld_library_path;
55     common::SetEnv("LD_LIBRARY_PATH", ld_library_path.c_str());
56     std::string aoe_plugin_path = "lib64/libaoe_tuning.so";
57     auto plugin_path = ascend_path + aoe_plugin_path;
58     auto ret = access(plugin_path.c_str(), F_OK);
59     if (ret != 0) {
60       MS_LOG(WARNING) << "plugin " << plugin_path << " not exist";
61       return;
62     }
63 
64     const std::vector<std::string> depend_libs = {"libopat.so", "libaoe_plugin.so", "libparser_common.so"};
65     for (const auto &dep_lib : depend_libs) {
66       auto dep_lip_path = ascend_path + "lib64/" + dep_lib;
67       auto dep_handler = dlopen(dep_lip_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
68       if (dep_handler != nullptr) {
69         depend_handler_.push_back(dep_handler);
70       } else {
71         MS_LOG(INFO) << "Cannot dlopen " << dep_lip_path << ", result = " << GetDlErrorMsg()
72                      << ", it can be ignored if not use aoe.";
73       }
74     }
75 
76     plugin_handle_ = dlopen(plugin_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
77     if (plugin_handle_ == nullptr) {
78       MS_LOG(INFO) << "Cannot dlopen " << plugin_path << ", result = " << GetDlErrorMsg()
79                    << ", it can be ignored if not use aoe.";
80       return;
81     }
82     MS_LOG(INFO) << "load " << aoe_plugin_path << " success";
83     aoe_initialize_ = DlsymFuncObj(AoeInitialize, plugin_handle_);
84     aoe_finalize_ = DlsymFuncObj(AoeFinalize, plugin_handle_);
85     aoe_create_session_ = DlsymFuncObj(AoeCreateSession, plugin_handle_);
86     aoe_set_ge_gession_ = DlsymFuncObj(AoeSetGeSession, plugin_handle_);
87     aoe_set_tuning_graph_ = DlsymFuncObj(AoeSetTuningGraph, plugin_handle_);
88     aoe_tuning_graph_ = DlsymFuncObj(AoeTuningGraph, plugin_handle_);
89     aoe_destroy_session_ = DlsymFuncObj(AoeDestroySession, plugin_handle_);
90     auto ms_context = MsContext::GetInstance();
91     std::string aoe_job_type = ms_context->get_param<std::string>(MS_CTX_AOE_JOB_TYPE);
92     std::map<::ge::AscendString, ::ge::AscendString> globalOptions = {
93       {AoeOptions::JOB_TYPE, ::ge::AscendString(aoe_job_type.c_str())}};
94     const AoeStatus status = aoe_initialize_(globalOptions);
95     if (status != AOE_SUCCESS) {
96       MS_LOG(ERROR) << "AoeInitialize failed. Please refer to 'Ascend Optimization Engine' at "
97                     << "https://www.mindspore.cn to set environment variables.";
98     }
99     MS_LOG(INFO) << "AoeInitialize success.";
100     initialize_ = true;
101   }
102 }
103 
Destroy()104 void AoeUtil::Destroy() {
105   if (!initialize_) {
106     MS_LOG(WARNING) << "AOE not initialize, stop destroy";
107     return;
108   }
109   if (IsAscendServer()) {
110     try {
111       const AoeStatus status = aoe_finalize_();
112       if (status != AOE_SUCCESS) {
113         MS_LOG(ERROR) << "AoeFinalize failed. status is " << status;
114       }
115     } catch (const std::exception &e) {
116       MS_LOG(ERROR) << "Error occurred when exec aoe finalize. Error:" << e.what();
117     } catch (...) {
118       std::string exName(abi::__cxa_current_exception_type()->name());
119       MS_LOG(ERROR) << "Error occurred when  exec aoe finalize. Exception name: " << exName;
120     }
121   }
122   if (plugin_handle_ == nullptr) {
123     return;
124   }
125   aoe_initialize_ = nullptr;
126   aoe_finalize_ = nullptr;
127   aoe_create_session_ = nullptr;
128   aoe_set_ge_gession_ = nullptr;
129   aoe_set_tuning_graph_ = nullptr;
130   aoe_tuning_graph_ = nullptr;
131   aoe_destroy_session_ = nullptr;
132   MS_LOG(INFO) << "AoeFinalization success.";
133   for (const auto &dep_handler : depend_handler_) {
134     (void)dlclose(dep_handler);
135   }
136   (void)dlclose(plugin_handle_);
137   plugin_handle_ = nullptr;
138   initialize_ = false;
139 }
140 
GetInstance()141 AoeUtil &AoeUtil::GetInstance() {
142   static AoeUtil instance{};
143   return instance;
144 }
145 
AoeGeGraph(::ge::Session * ge_session,const transform::DfGraphPtr & graph,const std::map<::ge::AscendString,::ge::AscendString> & tuningOptions) const146 Status AoeUtil::AoeGeGraph(::ge::Session *ge_session, const transform::DfGraphPtr &graph,
147                            const std::map<::ge::AscendString, ::ge::AscendString> &tuningOptions) const {
148   uint64_t sessionId = 0;
149   AoeStatus status = aoe_create_session_(sessionId);
150   if (status != AOE_SUCCESS) {
151     MS_LOG(ERROR) << "AoeCreateSession failed. error code:" << status;
152     return FAILED;
153   }
154   MS_LOG(DEBUG) << "AoeCreateSession success.";
155 
156   status = aoe_set_ge_gession_(sessionId, ge_session);
157   if (status != AOE_SUCCESS) {
158     MS_LOG(ERROR) << "AoeSetGeSession failed. error code:" << status;
159     return FAILED;
160   }
161   MS_LOG(DEBUG) << "->AoeSetGeSession success.";
162 
163   status = aoe_set_tuning_graph_(sessionId, *graph);
164   if (status != AOE_SUCCESS) {
165     MS_LOG(ERROR) << "AoeSetGraph failed. error code:" << status;
166     return FAILED;
167   }
168   MS_LOG(DEBUG) << "->AoeSetGraph success.";
169 
170   status = aoe_tuning_graph_(sessionId, tuningOptions);
171   if (status != AOE_SUCCESS && status != AOE_ERROR_NON_OPTIMIZE_GRAPH) {
172     MS_LOG(ERROR) << "AoeTuningGraph failed. error code:" << status;
173     (void)aoe_destroy_session_(sessionId);
174     return FAILED;
175   }
176   MS_LOG(DEBUG) << "->AoeTuningGraph success.";
177 
178   status = aoe_destroy_session_(sessionId);
179   if (status != AOE_SUCCESS) {
180     MS_LOG(ERROR) << "AoeDestroySession failed. error code:" << status;
181     return FAILED;
182   }
183   return SUCCESS;
184 }
185 
AoeOnlineGeGraph(const std::shared_ptr<::ge::Session> & ge_session,const transform::DfGraphPtr & graph) const186 Status AoeUtil::AoeOnlineGeGraph(const std::shared_ptr<::ge::Session> &ge_session,
187                                  const transform::DfGraphPtr &graph) const {
188   MS_LOG(DEBUG) << "AoeOnlineGeGraph start.";
189   if (!initialize_) {
190     MS_LOG(WARNING) << "AOE not initialize";
191     return FAILED;
192   }
193   if (ge_session == nullptr) {
194     MS_LOG(ERROR) << "sess is null";
195     return FAILED;
196   }
197   auto ms_context = MsContext::GetInstance();
198   MS_EXCEPTION_IF_NULL(ms_context);
199   const auto &soc_version = ms_context->ascend_soc_version();
200   ::ge::AscendString precision_mode = "allow_fp32_to_fp16";
201   if (soc_version == "ascend910b" || soc_version == "ascend910c") {
202     precision_mode = "must_keep_origin_dtype";
203   }
204 
205   std::map<::ge::AscendString, ::ge::AscendString> tuneOptions = {
206     {AoeOptions::FRAMEWORK, ::ge::AscendString("1")},
207     {AoeOptions::PRECISION_MODE, precision_mode},
208     {AoeOptions::LOG_LEVEL, ::ge::AscendString("error")},
209   };
210 
211   if (AoeGeGraph(ge_session.get(), graph, tuneOptions) != SUCCESS) {
212     MS_LOG(ERROR) << "Failed to call Aoe online tuning.";
213     return FAILED;
214   }
215 
216   MS_LOG(DEBUG) << "AoeTuningGraph success.";
217   return SUCCESS;
218 }
219 
SaveOptimizedGraph(const int32_t & graph_id)220 void AoeUtil::SaveOptimizedGraph(const int32_t &graph_id) { optimized_graphs_id_.insert(graph_id); }
221 
IsSaveOptimizedGraph(const int32_t & graph_id) const222 bool AoeUtil::IsSaveOptimizedGraph(const int32_t &graph_id) const {
223   auto iter_find = optimized_graphs_id_.find(graph_id);
224   if (iter_find != optimized_graphs_id_.end()) {
225     return true;
226   }
227   return false;
228 }
229 
RemoveWaitOptimizedGraph(const std::set<std::string> & optimized_graph_names)230 void AoeUtil::RemoveWaitOptimizedGraph(const std::set<std::string> &optimized_graph_names) {
231   for (auto &graph_name : optimized_graph_names) {
232     if (auto remove_iter = wait_optimize_graphs_.find(graph_name); remove_iter != wait_optimize_graphs_.end())
233       (void)wait_optimize_graphs_.erase(remove_iter);
234   }
235   if (!wait_optimize_graphs_.empty()) {
236     MS_LOG(WARNING) << "optimize_graphs_ is not empty";
237   }
238 }
239 
AddOptimizeGraph(const std::string & graph_name)240 void AoeUtil::AddOptimizeGraph(const std::string &graph_name) { wait_optimize_graphs_.insert(graph_name); }
241 
GetWaitOptimizeGraph() const242 std::set<std::string> AoeUtil::GetWaitOptimizeGraph() const { return wait_optimize_graphs_; }
243 
SetOfflineEnvDumpGeGraph()244 void AoeUtil::SetOfflineEnvDumpGeGraph() {
245   auto file_path = GetSaveGraphsPathName("aoe_dump");
246   auto real_path = FileUtils::CreateNotExistDirs(file_path, true);
247   if (!real_path.has_value()) {
248     MS_LOG(WARNING) << "fail to create aoe dump dir " << real_path.value();
249     return;
250   }
251   common::SetEnv("DUMP_GE_GRAPH", "1");
252   common::SetEnv("DUMP_GRAPH_LEVEL", "4");
253   common::SetEnv("DUMP_GRAPH_PATH", real_path.value().c_str());
254 }
255 }  // namespace transform
256 }  // namespace mindspore
257