1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <dlfcn.h>
17 #include <cxxabi.h>
18 #include <set>
19 #include <string>
20 #include "include/common/debug/common.h"
21 #include "transform/graph_ir/aoe_util.h"
22 #include "utils/file_utils.h"
23 #include "utils/ms_context.h"
24 #include "transform/symbol/acl_base_symbol.h"
25 #include "transform/symbol/symbol_utils.h"
26
27 namespace mindspore {
28 namespace transform {
29 namespace AoeOptions {
30 const ::ge::AscendString JOB_TYPE = ::ge::AscendString("job_type");
31 const ::ge::AscendString FRAMEWORK = ::ge::AscendString("framework");
32 const ::ge::AscendString LOG_LEVEL = ::ge::AscendString("log");
33 const ::ge::AscendString PRECISION_MODE = ::ge::AscendString("precision_mode");
34 } // namespace AoeOptions
35
IsAscendServer()36 bool IsAscendServer() {
37 auto ms_context = MsContext::GetInstance();
38 MS_EXCEPTION_IF_NULL(ms_context);
39 return ms_context->ascend_soc_version().find("ascend910") != std::string::npos;
40 }
41
AoeUtil()42 AoeUtil::AoeUtil() : initialize_(false) {}
43
~AoeUtil()44 AoeUtil::~AoeUtil() { MS_LOG(INFO) << "release aoeutil success."; }
45
Initialize()46 void AoeUtil::Initialize() {
47 if (initialize_) {
48 MS_LOG(INFO) << "Aoe already initialized.";
49 return;
50 }
51 if (IsAscendServer()) {
52 std::string ascend_path = GetAscendPath();
53 auto ld_library_path = common::GetEnv("LD_LIBRARY_PATH");
54 ld_library_path = ascend_path + "lib64:" + ld_library_path;
55 common::SetEnv("LD_LIBRARY_PATH", ld_library_path.c_str());
56 std::string aoe_plugin_path = "lib64/libaoe_tuning.so";
57 auto plugin_path = ascend_path + aoe_plugin_path;
58 auto ret = access(plugin_path.c_str(), F_OK);
59 if (ret != 0) {
60 MS_LOG(WARNING) << "plugin " << plugin_path << " not exist";
61 return;
62 }
63
64 const std::vector<std::string> depend_libs = {"libopat.so", "libaoe_plugin.so", "libparser_common.so"};
65 for (const auto &dep_lib : depend_libs) {
66 auto dep_lip_path = ascend_path + "lib64/" + dep_lib;
67 auto dep_handler = dlopen(dep_lip_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
68 if (dep_handler != nullptr) {
69 depend_handler_.push_back(dep_handler);
70 } else {
71 MS_LOG(INFO) << "Cannot dlopen " << dep_lip_path << ", result = " << GetDlErrorMsg()
72 << ", it can be ignored if not use aoe.";
73 }
74 }
75
76 plugin_handle_ = dlopen(plugin_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
77 if (plugin_handle_ == nullptr) {
78 MS_LOG(INFO) << "Cannot dlopen " << plugin_path << ", result = " << GetDlErrorMsg()
79 << ", it can be ignored if not use aoe.";
80 return;
81 }
82 MS_LOG(INFO) << "load " << aoe_plugin_path << " success";
83 aoe_initialize_ = DlsymFuncObj(AoeInitialize, plugin_handle_);
84 aoe_finalize_ = DlsymFuncObj(AoeFinalize, plugin_handle_);
85 aoe_create_session_ = DlsymFuncObj(AoeCreateSession, plugin_handle_);
86 aoe_set_ge_gession_ = DlsymFuncObj(AoeSetGeSession, plugin_handle_);
87 aoe_set_tuning_graph_ = DlsymFuncObj(AoeSetTuningGraph, plugin_handle_);
88 aoe_tuning_graph_ = DlsymFuncObj(AoeTuningGraph, plugin_handle_);
89 aoe_destroy_session_ = DlsymFuncObj(AoeDestroySession, plugin_handle_);
90 auto ms_context = MsContext::GetInstance();
91 std::string aoe_job_type = ms_context->get_param<std::string>(MS_CTX_AOE_JOB_TYPE);
92 std::map<::ge::AscendString, ::ge::AscendString> globalOptions = {
93 {AoeOptions::JOB_TYPE, ::ge::AscendString(aoe_job_type.c_str())}};
94 const AoeStatus status = aoe_initialize_(globalOptions);
95 if (status != AOE_SUCCESS) {
96 MS_LOG(ERROR) << "AoeInitialize failed. Please refer to 'Ascend Optimization Engine' at "
97 << "https://www.mindspore.cn to set environment variables.";
98 }
99 MS_LOG(INFO) << "AoeInitialize success.";
100 initialize_ = true;
101 }
102 }
103
Destroy()104 void AoeUtil::Destroy() {
105 if (!initialize_) {
106 MS_LOG(WARNING) << "AOE not initialize, stop destroy";
107 return;
108 }
109 if (IsAscendServer()) {
110 try {
111 const AoeStatus status = aoe_finalize_();
112 if (status != AOE_SUCCESS) {
113 MS_LOG(ERROR) << "AoeFinalize failed. status is " << status;
114 }
115 } catch (const std::exception &e) {
116 MS_LOG(ERROR) << "Error occurred when exec aoe finalize. Error:" << e.what();
117 } catch (...) {
118 std::string exName(abi::__cxa_current_exception_type()->name());
119 MS_LOG(ERROR) << "Error occurred when exec aoe finalize. Exception name: " << exName;
120 }
121 }
122 if (plugin_handle_ == nullptr) {
123 return;
124 }
125 aoe_initialize_ = nullptr;
126 aoe_finalize_ = nullptr;
127 aoe_create_session_ = nullptr;
128 aoe_set_ge_gession_ = nullptr;
129 aoe_set_tuning_graph_ = nullptr;
130 aoe_tuning_graph_ = nullptr;
131 aoe_destroy_session_ = nullptr;
132 MS_LOG(INFO) << "AoeFinalization success.";
133 for (const auto &dep_handler : depend_handler_) {
134 (void)dlclose(dep_handler);
135 }
136 (void)dlclose(plugin_handle_);
137 plugin_handle_ = nullptr;
138 initialize_ = false;
139 }
140
GetInstance()141 AoeUtil &AoeUtil::GetInstance() {
142 static AoeUtil instance{};
143 return instance;
144 }
145
AoeGeGraph(::ge::Session * ge_session,const transform::DfGraphPtr & graph,const std::map<::ge::AscendString,::ge::AscendString> & tuningOptions) const146 Status AoeUtil::AoeGeGraph(::ge::Session *ge_session, const transform::DfGraphPtr &graph,
147 const std::map<::ge::AscendString, ::ge::AscendString> &tuningOptions) const {
148 uint64_t sessionId = 0;
149 AoeStatus status = aoe_create_session_(sessionId);
150 if (status != AOE_SUCCESS) {
151 MS_LOG(ERROR) << "AoeCreateSession failed. error code:" << status;
152 return FAILED;
153 }
154 MS_LOG(DEBUG) << "AoeCreateSession success.";
155
156 status = aoe_set_ge_gession_(sessionId, ge_session);
157 if (status != AOE_SUCCESS) {
158 MS_LOG(ERROR) << "AoeSetGeSession failed. error code:" << status;
159 return FAILED;
160 }
161 MS_LOG(DEBUG) << "->AoeSetGeSession success.";
162
163 status = aoe_set_tuning_graph_(sessionId, *graph);
164 if (status != AOE_SUCCESS) {
165 MS_LOG(ERROR) << "AoeSetGraph failed. error code:" << status;
166 return FAILED;
167 }
168 MS_LOG(DEBUG) << "->AoeSetGraph success.";
169
170 status = aoe_tuning_graph_(sessionId, tuningOptions);
171 if (status != AOE_SUCCESS && status != AOE_ERROR_NON_OPTIMIZE_GRAPH) {
172 MS_LOG(ERROR) << "AoeTuningGraph failed. error code:" << status;
173 (void)aoe_destroy_session_(sessionId);
174 return FAILED;
175 }
176 MS_LOG(DEBUG) << "->AoeTuningGraph success.";
177
178 status = aoe_destroy_session_(sessionId);
179 if (status != AOE_SUCCESS) {
180 MS_LOG(ERROR) << "AoeDestroySession failed. error code:" << status;
181 return FAILED;
182 }
183 return SUCCESS;
184 }
185
AoeOnlineGeGraph(const std::shared_ptr<::ge::Session> & ge_session,const transform::DfGraphPtr & graph) const186 Status AoeUtil::AoeOnlineGeGraph(const std::shared_ptr<::ge::Session> &ge_session,
187 const transform::DfGraphPtr &graph) const {
188 MS_LOG(DEBUG) << "AoeOnlineGeGraph start.";
189 if (!initialize_) {
190 MS_LOG(WARNING) << "AOE not initialize";
191 return FAILED;
192 }
193 if (ge_session == nullptr) {
194 MS_LOG(ERROR) << "sess is null";
195 return FAILED;
196 }
197 auto ms_context = MsContext::GetInstance();
198 MS_EXCEPTION_IF_NULL(ms_context);
199 const auto &soc_version = ms_context->ascend_soc_version();
200 ::ge::AscendString precision_mode = "allow_fp32_to_fp16";
201 if (soc_version == "ascend910b" || soc_version == "ascend910c") {
202 precision_mode = "must_keep_origin_dtype";
203 }
204
205 std::map<::ge::AscendString, ::ge::AscendString> tuneOptions = {
206 {AoeOptions::FRAMEWORK, ::ge::AscendString("1")},
207 {AoeOptions::PRECISION_MODE, precision_mode},
208 {AoeOptions::LOG_LEVEL, ::ge::AscendString("error")},
209 };
210
211 if (AoeGeGraph(ge_session.get(), graph, tuneOptions) != SUCCESS) {
212 MS_LOG(ERROR) << "Failed to call Aoe online tuning.";
213 return FAILED;
214 }
215
216 MS_LOG(DEBUG) << "AoeTuningGraph success.";
217 return SUCCESS;
218 }
219
SaveOptimizedGraph(const int32_t & graph_id)220 void AoeUtil::SaveOptimizedGraph(const int32_t &graph_id) { optimized_graphs_id_.insert(graph_id); }
221
IsSaveOptimizedGraph(const int32_t & graph_id) const222 bool AoeUtil::IsSaveOptimizedGraph(const int32_t &graph_id) const {
223 auto iter_find = optimized_graphs_id_.find(graph_id);
224 if (iter_find != optimized_graphs_id_.end()) {
225 return true;
226 }
227 return false;
228 }
229
RemoveWaitOptimizedGraph(const std::set<std::string> & optimized_graph_names)230 void AoeUtil::RemoveWaitOptimizedGraph(const std::set<std::string> &optimized_graph_names) {
231 for (auto &graph_name : optimized_graph_names) {
232 if (auto remove_iter = wait_optimize_graphs_.find(graph_name); remove_iter != wait_optimize_graphs_.end())
233 (void)wait_optimize_graphs_.erase(remove_iter);
234 }
235 if (!wait_optimize_graphs_.empty()) {
236 MS_LOG(WARNING) << "optimize_graphs_ is not empty";
237 }
238 }
239
AddOptimizeGraph(const std::string & graph_name)240 void AoeUtil::AddOptimizeGraph(const std::string &graph_name) { wait_optimize_graphs_.insert(graph_name); }
241
GetWaitOptimizeGraph() const242 std::set<std::string> AoeUtil::GetWaitOptimizeGraph() const { return wait_optimize_graphs_; }
243
SetOfflineEnvDumpGeGraph()244 void AoeUtil::SetOfflineEnvDumpGeGraph() {
245 auto file_path = GetSaveGraphsPathName("aoe_dump");
246 auto real_path = FileUtils::CreateNotExistDirs(file_path, true);
247 if (!real_path.has_value()) {
248 MS_LOG(WARNING) << "fail to create aoe dump dir " << real_path.value();
249 return;
250 }
251 common::SetEnv("DUMP_GE_GRAPH", "1");
252 common::SetEnv("DUMP_GRAPH_LEVEL", "4");
253 common::SetEnv("DUMP_GRAPH_PATH", real_path.value().c_str());
254 }
255 } // namespace transform
256 } // namespace mindspore
257