1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <limits>
#include <set>
#include <sstream>
#include "mindspore/core/utils/file_utils.h"
#include "transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/aoe_util.h"
#include "utils/ms_context.h"
#include "pipeline/jit/ps/base.h"
#include "utils/phase.h"
#ifndef ENABLE_LITE_ACL
#include "include/common/utils/python_adapter.h"
#endif
#include "include/common/utils/compile_cache_context.h"
#include "include/common/debug/common.h"
30
31 namespace mindspore {
32 namespace transform {
DfGraphWrapper(const std::string & name,const int & id,const DfGraphPtr & graph_ptr,const OptionMap & options)33 DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr,
34 const OptionMap &options)
35 : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {}
36
DfGraphManager()37 DfGraphManager::DfGraphManager() {
38 graph_id_ = 0;
39 graph_runner_ptr_ = nullptr;
40 sess_ptr_ = nullptr;
41 }
42
~DfGraphManager()43 DfGraphManager::~DfGraphManager() {
44 // in python first destroy after atexit but in c++ destoy before atexit
45 DeleteGraphRunner();
46 DeleteGeSession();
47 ClearGraph();
48 #ifndef ENABLE_LITE_ACL
49 python_adapter::set_python_env_flag(false);
50 #endif
51 }
52
GetInstance()53 DfGraphManager &DfGraphManager::GetInstance() {
54 static DfGraphManager instance{};
55 return instance;
56 }
57
GenerateId()58 int DfGraphManager::GenerateId() {
59 graph_id_++;
60 if (graph_id_ <= 0) {
61 graph_id_ = 1;
62 }
63 MS_LOG(INFO) << "Generate graph Id : " << graph_id_;
64 return graph_id_;
65 }
66
AddGraph(const std::string & name,const DfGraphPtr & graph_ptr,const OptionMap & options,const bool & is_cloud)67 Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options,
68 const bool &is_cloud) {
69 std::lock_guard<std::mutex> lg(lock_);
70 if (name.empty()) {
71 MS_LOG(ERROR) << "The graph name is null, add graph failed";
72 return Status::INVALID_ARGUMENT;
73 }
74
75 if (graph_ptr == nullptr) {
76 MS_LOG(INFO) << "The new graph {" << name << "}'s pointer is null, cannot add graph.";
77 return Status::INVALID_ARGUMENT;
78 }
79
80 int id = GenerateId();
81 OptionMap new_options = options;
82 auto ms_context_ptr = MsContext::GetInstance();
83 MS_EXCEPTION_IF_NULL(ms_context_ptr);
84 auto soc_version = ms_context_ptr->ascend_soc_version();
85 if (ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE) != "") {
86 (new_options)["ge.exec.precision_mode"] = ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE);
87 MS_LOG(INFO) << "Set precision_mode " << ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE)
88 << " by user.";
89 } else if (is_cloud) {
90 if (soc_version == "ascend910b" || soc_version == "ascend910c") {
91 (new_options)["ge.exec.precision_mode"] = "must_keep_origin_dtype";
92 MS_LOG(INFO) << "Set precision_mode must_keep_origin_dtype, soc_version is " << soc_version << ".";
93 } else {
94 (new_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
95 MS_LOG(INFO) << "Set precision_mode allow_fp32_to_fp16, soc_version is " << soc_version << ".";
96 }
97 } else {
98 (new_options)["ge.exec.precision_mode"] = "force_fp16";
99 MS_LOG(INFO) << "Set precision_mode force_fp16, soc_version is " << soc_version << ".";
100 }
101 auto &compile_cache_context = CompileCacheContext::GetInstance();
102 auto init_compile_cache = compile_cache_context.init_compile_cache();
103 auto dep_files_hash = compile_cache_context.CompileCacheDepFilesHash();
104 if (CompileCacheEnable() && init_compile_cache) {
105 auto ge_graph_key = IsEnableRefMode() ? name : std::to_string(id);
106 if (!dep_files_hash.empty()) {
107 ge_graph_key = dep_files_hash + "_" + ge_graph_key;
108 }
109 ge_graph_key = NormalizeString(ge_graph_key);
110 new_options.insert_or_assign(kGeGraphKey, ge_graph_key);
111 auto ge_cache_path = Common::GetCompilerCachePath() + kGeCache;
112 (void)mindspore::FileUtils::CreateNotExistDirs(ge_cache_path, true);
113 new_options.insert_or_assign(kGeGraphCompilerCacheDir, ge_cache_path);
114 MS_LOG(INFO) << "Use GE graph compile cache, GE graph compile cache dir: " << ge_cache_path
115 << ", the ge.graph_key is " << ge_graph_key;
116 }
117
118 DfGraphWrapperPtr wrap_ptr = std::make_shared<DfGraphWrapper>(name, id, graph_ptr, new_options);
119 auto ret = graphs_.emplace(name, wrap_ptr);
120 if (!ret.second) {
121 MS_LOG(WARNING) << "The graph name:{ " << name << " }is already exists! The old graph will be overwritten!!";
122 ret.first->second = wrap_ptr;
123 }
124 MS_LOG(INFO) << "Add graph " << name << " to GraphManager success!";
125 return Status::SUCCESS;
126 }
127
GetAllGraphs()128 std::vector<DfGraphWrapperPtr> DfGraphManager::GetAllGraphs() {
129 std::lock_guard<std::mutex> lg(lock_);
130 std::vector<DfGraphWrapperPtr> ret;
131 std::stringstream ss;
132 ss << "{ ";
133 for (auto it = graphs_.begin(); it != graphs_.end(); ++it) {
134 ss << it->first << ", ";
135 (void)ret.emplace_back(it->second);
136 }
137 ss << "}";
138 MS_LOG(INFO) << "Return graphs: " << ss.str();
139 return ret;
140 }
GetSavedGraphs()141 std::set<string> DfGraphManager::GetSavedGraphs() { return saved_graphs_; }
142
AddSavedGraphs(const std::string & id)143 void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); }
144
GetGraphByName(const std::string & name)145 DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) {
146 std::lock_guard<std::mutex> lg(lock_);
147 if (name.empty()) {
148 MS_LOG(ERROR) << "The graph name is null";
149 return nullptr;
150 }
151
152 auto it = graphs_.find(name);
153 if (it == graphs_.end()) {
154 MS_LOG(INFO) << "Can't found graph name: " << name;
155 return nullptr;
156 }
157 MS_LOG(INFO) << "Return graph: " << name;
158 return it->second;
159 }
160
ClearGraph()161 void DfGraphManager::ClearGraph() noexcept {
162 std::lock_guard<std::mutex> lg(lock_);
163 for (const auto &graph_id : graphs_) {
164 MS_LOG(INFO) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
165 if (sess_ptr_ != nullptr &&
166 sess_ptr_->RemoveGraph(static_cast<uint32_t>(graph_id.second->id_)) != ::ge::GRAPH_SUCCESS) {
167 MS_LOG(WARNING) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
168 }
169 }
170 graphs_.clear();
171 anf_graphs_.clear();
172 MS_LOG(INFO) << "Remove all graphs in GraphManager";
173 }
174
SetAnfGraph(const std::string & name,const AnfGraphPtr & anf_graph_ptr)175 void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) {
176 DfGraphWrapperPtr df_graph = GetGraphByName(name);
177 if (df_graph == nullptr) {
178 MS_LOG(ERROR) << "Can't found graph name: " << name;
179 return;
180 }
181 std::lock_guard<std::mutex> lg(lock_);
182 anf_graphs_[df_graph->id_] = anf_graph_ptr;
183 }
184
GetAnfGraph(uint32_t graph_id)185 AnfGraphPtr DfGraphManager::GetAnfGraph(uint32_t graph_id) {
186 std::lock_guard<std::mutex> lg(lock_);
187 auto iter = anf_graphs_.find(graph_id);
188 if (iter == anf_graphs_.end()) {
189 MS_LOG(ERROR) << "Can't found anf graph, graph_id = " << graph_id;
190 return nullptr;
191 }
192
193 return iter->second;
194 }
195
SetGeSession(const std::shared_ptr<::ge::Session> & sess_ptr)196 void DfGraphManager::SetGeSession(const std::shared_ptr<::ge::Session> &sess_ptr) {
197 std::lock_guard<std::mutex> lg(lock_);
198 if (sess_ptr == nullptr) {
199 return;
200 }
201
202 if (sess_ptr_ == nullptr) {
203 MS_LOG(INFO) << "Add a new Ge Session success";
204 } else {
205 MS_LOG(INFO) << "Add a new Ge Session success, the old Ge Session will be overwritten!!";
206 }
207 sess_ptr_ = sess_ptr;
208 }
209
GetGeSession()210 std::shared_ptr<::ge::Session> DfGraphManager::GetGeSession() {
211 std::lock_guard<std::mutex> lg(lock_);
212 return sess_ptr_;
213 }
214
DeleteGeSession()215 void DfGraphManager::DeleteGeSession() noexcept {
216 std::lock_guard<std::mutex> lg(lock_);
217 if (sess_ptr_ == nullptr) {
218 MS_LOG(INFO) << "Ge Session is not exist";
219 } else {
220 for (const auto &graph_id : graphs_) {
221 MS_LOG(INFO) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
222 if (sess_ptr_->RemoveGraph(static_cast<uint32_t>(graph_id.second->id_)) != ::ge::GRAPH_SUCCESS) {
223 MS_LOG(WARNING) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
224 }
225 }
226 sess_ptr_ = nullptr;
227 saved_graphs_.clear();
228 MS_LOG(INFO) << "Delete Ge Session success";
229 }
230 }
231
SetGraphRunner(const std::shared_ptr<transform::GraphRunner> & graph_runner_ptr)232 void DfGraphManager::SetGraphRunner(const std::shared_ptr<transform::GraphRunner> &graph_runner_ptr) noexcept {
233 std::lock_guard<std::mutex> lg(lock_);
234 if (graph_runner_ptr == nullptr) {
235 MS_LOG(WARNING) << "You are adding a empty GraphRunner";
236 }
237
238 if (graph_runner_ptr_ == nullptr) {
239 MS_LOG(INFO) << "Add a new GraphRunner success";
240 } else {
241 MS_LOG(INFO) << "Add a new GraphRunner success, the old GraphRunner will be overwritten!!";
242 }
243 graph_runner_ptr_ = graph_runner_ptr;
244 }
245
GetGraphRunner()246 std::shared_ptr<transform::GraphRunner> DfGraphManager::GetGraphRunner() {
247 std::lock_guard<std::mutex> lg(lock_);
248 return graph_runner_ptr_;
249 }
250
DeleteGraphRunner()251 void DfGraphManager::DeleteGraphRunner() noexcept {
252 std::lock_guard<std::mutex> lg(lock_);
253 if (graph_runner_ptr_ == nullptr) {
254 MS_LOG(INFO) << "GraphRunner is not exist";
255 } else {
256 graph_runner_ptr_ = nullptr;
257 MS_LOG(INFO) << "Delete GraphRunner success";
258 }
259 }
260
AoeGeGraph()261 void DfGraphManager::AoeGeGraph() {
262 std::set<string> wait_optimize_graphs_ = AoeUtil::GetInstance().GetWaitOptimizeGraph();
263 if (wait_optimize_graphs_.empty()) {
264 return;
265 }
266 MS_LOG(DEBUG) << "start optimized graph";
267 std::set<string> optimized_graph_names_;
268 #ifndef ENABLE_LITE_ACL
269 py::gil_scoped_release release;
270 #endif
271
272 for (auto &graph_name : wait_optimize_graphs_) {
273 auto wrapper = GetGraphByName(graph_name);
274 MS_EXCEPTION_IF_NULL(wrapper);
275 if (AoeUtil::GetInstance().IsSaveOptimizedGraph(wrapper->id_)) {
276 continue;
277 }
278 Status status = AoeUtil::GetInstance().AoeOnlineGeGraph(GetGeSession(), wrapper->graph_ptr_);
279 if (status == FAILED) {
280 MS_LOG(ERROR) << "AOE tuning failed, graph name is " << graph_name << " id :" << wrapper->id_;
281 return;
282 }
283 AoeUtil::GetInstance().SaveOptimizedGraph(wrapper->id_);
284 optimized_graph_names_.insert(graph_name);
285 MS_LOG(DEBUG) << "Optimized Graph " << graph_name << " success";
286 }
287 AoeUtil::GetInstance().RemoveWaitOptimizedGraph(optimized_graph_names_);
288 optimized_graph_names_.clear();
289 MS_LOG(DEBUG) << "optimized graph end";
290 }
291 } // namespace transform
292 } // namespace mindspore
293