• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <limits>
#include <set>
#include <sstream>
#include "mindspore/core/utils/file_utils.h"
#include "transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/aoe_util.h"
#include "utils/ms_context.h"
#include "pipeline/jit/ps/base.h"
#include "utils/phase.h"
#ifndef ENABLE_LITE_ACL
#include "include/common/utils/python_adapter.h"
#endif
#include "include/common/utils/compile_cache_context.h"
#include "include/common/debug/common.h"
30 
31 namespace mindspore {
32 namespace transform {
DfGraphWrapper(const std::string & name,const int & id,const DfGraphPtr & graph_ptr,const OptionMap & options)33 DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr,
34                                const OptionMap &options)
35     : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {}
36 
DfGraphManager()37 DfGraphManager::DfGraphManager() {
38   graph_id_ = 0;
39   graph_runner_ptr_ = nullptr;
40   sess_ptr_ = nullptr;
41 }
42 
~DfGraphManager()43 DfGraphManager::~DfGraphManager() {
44   // in python first destroy after atexit but in c++ destoy before atexit
45   DeleteGraphRunner();
46   DeleteGeSession();
47   ClearGraph();
48 #ifndef ENABLE_LITE_ACL
49   python_adapter::set_python_env_flag(false);
50 #endif
51 }
52 
GetInstance()53 DfGraphManager &DfGraphManager::GetInstance() {
54   static DfGraphManager instance{};
55   return instance;
56 }
57 
GenerateId()58 int DfGraphManager::GenerateId() {
59   graph_id_++;
60   if (graph_id_ <= 0) {
61     graph_id_ = 1;
62   }
63   MS_LOG(INFO) << "Generate graph Id : " << graph_id_;
64   return graph_id_;
65 }
66 
AddGraph(const std::string & name,const DfGraphPtr & graph_ptr,const OptionMap & options,const bool & is_cloud)67 Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options,
68                                 const bool &is_cloud) {
69   std::lock_guard<std::mutex> lg(lock_);
70   if (name.empty()) {
71     MS_LOG(ERROR) << "The graph name is null, add graph failed";
72     return Status::INVALID_ARGUMENT;
73   }
74 
75   if (graph_ptr == nullptr) {
76     MS_LOG(INFO) << "The new graph {" << name << "}'s pointer is null, cannot add graph.";
77     return Status::INVALID_ARGUMENT;
78   }
79 
80   int id = GenerateId();
81   OptionMap new_options = options;
82   auto ms_context_ptr = MsContext::GetInstance();
83   MS_EXCEPTION_IF_NULL(ms_context_ptr);
84   auto soc_version = ms_context_ptr->ascend_soc_version();
85   if (ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE) != "") {
86     (new_options)["ge.exec.precision_mode"] = ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE);
87     MS_LOG(INFO) << "Set precision_mode " << ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE)
88                  << " by user.";
89   } else if (is_cloud) {
90     if (soc_version == "ascend910b" || soc_version == "ascend910c") {
91       (new_options)["ge.exec.precision_mode"] = "must_keep_origin_dtype";
92       MS_LOG(INFO) << "Set precision_mode must_keep_origin_dtype, soc_version is " << soc_version << ".";
93     } else {
94       (new_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
95       MS_LOG(INFO) << "Set precision_mode allow_fp32_to_fp16, soc_version is " << soc_version << ".";
96     }
97   } else {
98     (new_options)["ge.exec.precision_mode"] = "force_fp16";
99     MS_LOG(INFO) << "Set precision_mode force_fp16, soc_version is " << soc_version << ".";
100   }
101   auto &compile_cache_context = CompileCacheContext::GetInstance();
102   auto init_compile_cache = compile_cache_context.init_compile_cache();
103   auto dep_files_hash = compile_cache_context.CompileCacheDepFilesHash();
104   if (CompileCacheEnable() && init_compile_cache) {
105     auto ge_graph_key = IsEnableRefMode() ? name : std::to_string(id);
106     if (!dep_files_hash.empty()) {
107       ge_graph_key = dep_files_hash + "_" + ge_graph_key;
108     }
109     ge_graph_key = NormalizeString(ge_graph_key);
110     new_options.insert_or_assign(kGeGraphKey, ge_graph_key);
111     auto ge_cache_path = Common::GetCompilerCachePath() + kGeCache;
112     (void)mindspore::FileUtils::CreateNotExistDirs(ge_cache_path, true);
113     new_options.insert_or_assign(kGeGraphCompilerCacheDir, ge_cache_path);
114     MS_LOG(INFO) << "Use GE graph compile cache, GE graph compile cache dir: " << ge_cache_path
115                  << ", the ge.graph_key is " << ge_graph_key;
116   }
117 
118   DfGraphWrapperPtr wrap_ptr = std::make_shared<DfGraphWrapper>(name, id, graph_ptr, new_options);
119   auto ret = graphs_.emplace(name, wrap_ptr);
120   if (!ret.second) {
121     MS_LOG(WARNING) << "The graph name:{ " << name << " }is already exists! The old graph will be overwritten!!";
122     ret.first->second = wrap_ptr;
123   }
124   MS_LOG(INFO) << "Add graph " << name << " to GraphManager success!";
125   return Status::SUCCESS;
126 }
127 
GetAllGraphs()128 std::vector<DfGraphWrapperPtr> DfGraphManager::GetAllGraphs() {
129   std::lock_guard<std::mutex> lg(lock_);
130   std::vector<DfGraphWrapperPtr> ret;
131   std::stringstream ss;
132   ss << "{ ";
133   for (auto it = graphs_.begin(); it != graphs_.end(); ++it) {
134     ss << it->first << ", ";
135     (void)ret.emplace_back(it->second);
136   }
137   ss << "}";
138   MS_LOG(INFO) << "Return graphs: " << ss.str();
139   return ret;
140 }
GetSavedGraphs()141 std::set<string> DfGraphManager::GetSavedGraphs() { return saved_graphs_; }
142 
AddSavedGraphs(const std::string & id)143 void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); }
144 
GetGraphByName(const std::string & name)145 DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) {
146   std::lock_guard<std::mutex> lg(lock_);
147   if (name.empty()) {
148     MS_LOG(ERROR) << "The graph name is null";
149     return nullptr;
150   }
151 
152   auto it = graphs_.find(name);
153   if (it == graphs_.end()) {
154     MS_LOG(INFO) << "Can't found graph name: " << name;
155     return nullptr;
156   }
157   MS_LOG(INFO) << "Return graph: " << name;
158   return it->second;
159 }
160 
ClearGraph()161 void DfGraphManager::ClearGraph() noexcept {
162   std::lock_guard<std::mutex> lg(lock_);
163   for (const auto &graph_id : graphs_) {
164     MS_LOG(INFO) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
165     if (sess_ptr_ != nullptr &&
166         sess_ptr_->RemoveGraph(static_cast<uint32_t>(graph_id.second->id_)) != ::ge::GRAPH_SUCCESS) {
167       MS_LOG(WARNING) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
168     }
169   }
170   graphs_.clear();
171   anf_graphs_.clear();
172   MS_LOG(INFO) << "Remove all graphs in GraphManager";
173 }
174 
SetAnfGraph(const std::string & name,const AnfGraphPtr & anf_graph_ptr)175 void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) {
176   DfGraphWrapperPtr df_graph = GetGraphByName(name);
177   if (df_graph == nullptr) {
178     MS_LOG(ERROR) << "Can't found graph name: " << name;
179     return;
180   }
181   std::lock_guard<std::mutex> lg(lock_);
182   anf_graphs_[df_graph->id_] = anf_graph_ptr;
183 }
184 
GetAnfGraph(uint32_t graph_id)185 AnfGraphPtr DfGraphManager::GetAnfGraph(uint32_t graph_id) {
186   std::lock_guard<std::mutex> lg(lock_);
187   auto iter = anf_graphs_.find(graph_id);
188   if (iter == anf_graphs_.end()) {
189     MS_LOG(ERROR) << "Can't found anf graph, graph_id = " << graph_id;
190     return nullptr;
191   }
192 
193   return iter->second;
194 }
195 
SetGeSession(const std::shared_ptr<::ge::Session> & sess_ptr)196 void DfGraphManager::SetGeSession(const std::shared_ptr<::ge::Session> &sess_ptr) {
197   std::lock_guard<std::mutex> lg(lock_);
198   if (sess_ptr == nullptr) {
199     return;
200   }
201 
202   if (sess_ptr_ == nullptr) {
203     MS_LOG(INFO) << "Add a new Ge Session success";
204   } else {
205     MS_LOG(INFO) << "Add a new Ge Session success, the old Ge Session will be overwritten!!";
206   }
207   sess_ptr_ = sess_ptr;
208 }
209 
GetGeSession()210 std::shared_ptr<::ge::Session> DfGraphManager::GetGeSession() {
211   std::lock_guard<std::mutex> lg(lock_);
212   return sess_ptr_;
213 }
214 
DeleteGeSession()215 void DfGraphManager::DeleteGeSession() noexcept {
216   std::lock_guard<std::mutex> lg(lock_);
217   if (sess_ptr_ == nullptr) {
218     MS_LOG(INFO) << "Ge Session is not exist";
219   } else {
220     for (const auto &graph_id : graphs_) {
221       MS_LOG(INFO) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
222       if (sess_ptr_->RemoveGraph(static_cast<uint32_t>(graph_id.second->id_)) != ::ge::GRAPH_SUCCESS) {
223         MS_LOG(WARNING) << "Remove graph, graph name: " << graph_id.first << ", graph id: " << graph_id.second->id_;
224       }
225     }
226     sess_ptr_ = nullptr;
227     saved_graphs_.clear();
228     MS_LOG(INFO) << "Delete Ge Session success";
229   }
230 }
231 
SetGraphRunner(const std::shared_ptr<transform::GraphRunner> & graph_runner_ptr)232 void DfGraphManager::SetGraphRunner(const std::shared_ptr<transform::GraphRunner> &graph_runner_ptr) noexcept {
233   std::lock_guard<std::mutex> lg(lock_);
234   if (graph_runner_ptr == nullptr) {
235     MS_LOG(WARNING) << "You are adding a empty GraphRunner";
236   }
237 
238   if (graph_runner_ptr_ == nullptr) {
239     MS_LOG(INFO) << "Add a new GraphRunner success";
240   } else {
241     MS_LOG(INFO) << "Add a new GraphRunner success, the old GraphRunner will be overwritten!!";
242   }
243   graph_runner_ptr_ = graph_runner_ptr;
244 }
245 
GetGraphRunner()246 std::shared_ptr<transform::GraphRunner> DfGraphManager::GetGraphRunner() {
247   std::lock_guard<std::mutex> lg(lock_);
248   return graph_runner_ptr_;
249 }
250 
DeleteGraphRunner()251 void DfGraphManager::DeleteGraphRunner() noexcept {
252   std::lock_guard<std::mutex> lg(lock_);
253   if (graph_runner_ptr_ == nullptr) {
254     MS_LOG(INFO) << "GraphRunner is not exist";
255   } else {
256     graph_runner_ptr_ = nullptr;
257     MS_LOG(INFO) << "Delete GraphRunner success";
258   }
259 }
260 
AoeGeGraph()261 void DfGraphManager::AoeGeGraph() {
262   std::set<string> wait_optimize_graphs_ = AoeUtil::GetInstance().GetWaitOptimizeGraph();
263   if (wait_optimize_graphs_.empty()) {
264     return;
265   }
266   MS_LOG(DEBUG) << "start optimized graph";
267   std::set<string> optimized_graph_names_;
268 #ifndef ENABLE_LITE_ACL
269   py::gil_scoped_release release;
270 #endif
271 
272   for (auto &graph_name : wait_optimize_graphs_) {
273     auto wrapper = GetGraphByName(graph_name);
274     MS_EXCEPTION_IF_NULL(wrapper);
275     if (AoeUtil::GetInstance().IsSaveOptimizedGraph(wrapper->id_)) {
276       continue;
277     }
278     Status status = AoeUtil::GetInstance().AoeOnlineGeGraph(GetGeSession(), wrapper->graph_ptr_);
279     if (status == FAILED) {
280       MS_LOG(ERROR) << "AOE tuning failed, graph name is " << graph_name << " id :" << wrapper->id_;
281       return;
282     }
283     AoeUtil::GetInstance().SaveOptimizedGraph(wrapper->id_);
284     optimized_graph_names_.insert(graph_name);
285     MS_LOG(DEBUG) << "Optimized Graph " << graph_name << " success";
286   }
287   AoeUtil::GetInstance().RemoveWaitOptimizedGraph(optimized_graph_names_);
288   optimized_graph_names_.clear();
289   MS_LOG(DEBUG) << "optimized graph end";
290 }
291 }  // namespace transform
292 }  // namespace mindspore
293