/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/compile_cache_manager.h"
#include <vector>
#include <algorithm>
#include <map>
#include <utility>
#include <fstream>
// For std::string, std::this_thread::sleep_for, stat() and nlohmann::json used below.
#include <string>
#include <thread>
#include <chrono>
#include <sys/stat.h>
#include "nlohmann/json.hpp"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "include/common/utils/parallel_context.h"
#include "include/common/debug/common.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "utils/system/sha256.h"
#include "include/common/utils/utils.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"
#include "mindspore/core/utils/file_utils.h"

#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/cluster/cluster_context.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif
#include "include/common/utils/compile_cache_context.h"
#include "include/common/utils/config_manager.h"

namespace mindspore {
#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
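// Serializes the parallel tensor layout of every parameter of func_graph (device arrangement, tensor map,
// slice shape, optimizer shard info and pipeline-shared info) into the MindIR model proto.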
void BuildLayout(const FuncGraphPtr &func_graph, mind_ir::ModelProto *model) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(model);
  std::vector<AnfNodePtr> graph_params = func_graph->parameters();
  mind_ir::ParallelProto *parallel_proto = model->mutable_parallel();
  for (auto para : graph_params) {
    std::string name = std::static_pointer_cast<Parameter>(para)->name();
    auto tensor_layout = para->user_data<parallel::TensorLayout>();
    if (tensor_layout == nullptr) {
      MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
    } else {
      mind_ir::LayoutProto *layoutProto = parallel_proto->add_layout();

      // Get all the information for the layout.
      auto device_arrangement = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();
      auto slice_shape = tensor_layout->slice_shape().array();
      int64_t field_size = tensor_layout->get_field_size();
      bool uniform_split = tensor_layout->uniform_split();
      std::string opt_shard_group = tensor_layout->opt_shard_group();
      if (!opt_shard_group.empty()) {
        slice_shape = tensor_layout->opt_shard_slice_shape();
      }
      // Save all the information to the layout proto.
      layoutProto->set_name(name);
      for (auto device_arrangement_element : device_arrangement) {
        layoutProto->add_device_arrangement_int(device_arrangement_element);
      }
      for (auto tensor_map_element : tensor_map) {
        layoutProto->add_tensor_map_int(tensor_map_element);
      }
      for (auto slice_shape_element : slice_shape) {
        layoutProto->add_slice_shape_int(slice_shape_element);
      }
      layoutProto->set_field_size(field_size);
      layoutProto->set_uniform_split(uniform_split);
      layoutProto->set_opt_shard_group(opt_shard_group);
      auto shared_param = para->user_data<parallel::SharedParameter>();
      if (shared_param) {
        layoutProto->set_pipeline_shared(shared_param->pipeline_shared());
        layoutProto->set_is_send(shared_param->is_send());
        layoutProto->set_peer_rank(shared_param->peer_rank());
        layoutProto->set_sr_tag(shared_param->sr_tag());
      }
    }
  }
}
#endif
namespace pipeline {
namespace {
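// Returns the compilation cache directory for this process: "<user-defined cache path>rank_<rank id>".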
std::string GetCompileCacheDir() {
  static const std::string user_defined_path = Common::GetUserDefineCachePath();
  static const uint32_t rank_id = IsStandAlone() ? 0 : GetRank();
  static const std::string compile_cache_dir = user_defined_path + "rank_" + std::to_string(rank_id);
  return compile_cache_dir;
}

std::string GetGraphCacheDir() { return GetCompileCacheDir() + "/" + kGraphCacheSubDir; }

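// Returns the node role reported by the cluster context when distributed training is initialized;
// otherwise returns an empty string.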
std::string GetRole() {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (distributed::cluster::ClusterContext::instance()->initialized()) {
    auto node = distributed::cluster::ClusterContext::instance()->node();
    MS_EXCEPTION_IF_NULL(node);
    const auto &cluster_ctx = distributed::cluster::ClusterContext::instance();
    MS_EXCEPTION_IF_NULL(cluster_ctx);
    MS_LOG(INFO) << "Cluster is initialized. This node role is " << cluster_ctx->node_role();
    return cluster_ctx->node_role();
  }
#endif
  return "";
}

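// Path of the cached frontend MindIR file for the graph with the given compile cache id.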
std::string GetCompileCachePath(size_t idx) {
  return GetGraphCacheDir() + "/" + GetRole() + kCompileCacheFileName + "_" + std::to_string(idx) + kMindIrSuffix;
}

// Path prefix (without extension) of the cached backend graph for the given compile cache id.
std::string GetBackendCompileCachePathWithoutExtension(size_t idx) {
  return GetGraphCacheDir() + "/" + GetRole() + kBackendCompileCacheFileName + "_" + std::to_string(idx);
}

// Path of the file that stores the hash of the compilation dependency files.
std::string GetDepFilesHashPath() {
  static const std::string dep_files_hash_path = GetGraphCacheDir() + "/" + GetRole() + kDepFilesHashPath;
  return dep_files_hash_path;
}

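// Save path of the parallel group checkpoint; the GROUP_INFO_FILE environment variable takes precedence.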
std::string GetGroupCkptSavePath(size_t index) {
  auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
  if (!group_info_save_path.empty()) {
    return group_info_save_path;
  }
  return GetGraphCacheDir() + "/group_" + std::to_string(index) + ".ckpt";
}

// Path of the JSON file that caches data queue names for the given data queue index.
std::string GetDataQueueNameCachePath(const std::string &data_queue_num) {
  std::string queue_name_cache_path =
    GetGraphCacheDir() + "/" + GetRole() + "_" + data_queue_num + kDataQueueNameCacheFileName;
  return queue_name_cache_path;
}

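// Computes a single hash over all compilation dependency files: the per-file SHA-256 hashes are
// concatenated in sorted path order and hashed again, so the result is independent of the input order.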
std::string GetCompileDepFilesHash(const py::list &dep_files) {
  MS_LOG(DEBUG) << "Dependency files size: " << dep_files.size();
  std::vector<std::string> dep_files_path;
  for (auto dep_file : dep_files) {
    auto file_path = py::cast<std::string>(dep_file);
    MS_LOG(DEBUG) << "Dependency file path: " << file_path;
    (void)dep_files_path.emplace_back(file_path);
  }
  std::sort(dep_files_path.begin(), dep_files_path.end());
  std::string files_hash;
  for (const auto &path : dep_files_path) {
    std::string file_hash = system::sha256::GetHashFromFile(path);
    files_hash += file_hash;
  }
  std::string files_hash_hash = system::sha256::GetHashFromString(files_hash);
  return files_hash_hash;
}

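// Converts the Python dict of weights into a C++ map from parameter name to ValuePtr.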
std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
  std::map<string, ValuePtr> ret{};
  for (auto weight = weights.begin(); weight != weights.end(); ++weight) {
    auto weight_name = py::cast<std::string>(weight->first);
    auto weight_value = parse::data_converter::PyDataToValue(py::cast<py::object>(weight->second));
    ret[weight_name] = weight_value;
  }
  return ret;
}

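// Loads the cached frontend graph (and, if requested, its parallel layout map) from the MindIR file
// for the given compile cache id. Returns {nullptr, {}} if the cache file is missing or unreadable.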
std::pair<FuncGraphPtr, LayoutMap> LoadFuncGraphFromMindIR(const py::dict &weights, bool has_parallel_info,
                                                           size_t idx) {
  LayoutMap layout_map;
  std::string compile_cache_path = GetCompileCachePath(idx);
  auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path of file " << compile_cache_path << " failed.";
    return std::make_pair(nullptr, layout_map);
  }
  struct stat buffer;
  if (stat(realpath.value().c_str(), &buffer) != 0) {
    MS_LOG(WARNING) << "Open the compilation cache file " << realpath.value() << " failed.";
    return std::make_pair(nullptr, layout_map);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);
  MindIRLoader mindir_loader;
  mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
  mindir_loader.set_has_parallel_info(has_parallel_info);
  mindspore::HashMap<std::string, AnfNodePtr> name_to_node;
  auto fg = mindir_loader.LoadMindIR(realpath.value(), &name_to_node);
  auto &context = CompileCacheContext::GetInstance();
  context.SetFrontNameToFrontNode(name_to_node);
  context.SetFrontGraph(fg);
  context.InsertBackendGraphCachePath(fg, GetBackendCompileCachePathWithoutExtension(idx));

  // LoadMindIR may raise the cell reuse level when the cached graph was compiled with @lazy_inline.
  if (ms_context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
    MS_LOG(INFO) << "Cell reuse(@lazy_inline) actually takes effect.";
  }
#if defined(__linux__) && defined(WITH_BACKEND)
  // The compile cache does not support host collective or graph kernel.
  if (common::UseHostCollective() || ms_context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL)) {
    context.SetRestrictedScenarios(true);
  }
#endif
  return std::make_pair(fg, mindir_loader.layout_map());
}

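// Serializes the frontend graph (plus the tensor layouts of layout_fg, when provided) to the MindIR
// cache file for the given compile cache id.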
bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, size_t idx) {
  std::string compile_cache_path = GetCompileCachePath(idx);
  auto proto = GenBinaryProto(fg);
  if (proto == nullptr) {
    MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
    return false;
  }
#ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
  if (layout_fg) {
    BuildLayout(layout_fg, proto.get());
  }
#endif
  auto &context = CompileCacheContext::GetInstance();
  context.SetFrontGraph(fg);
  context.InsertBackendGraphCachePath(fg, GetBackendCompileCachePathWithoutExtension(idx));
#if defined(__linux__) && defined(WITH_BACKEND)
  // The compile cache does not support host collective or graph kernel.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (common::UseHostCollective() || ms_context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL)) {
    context.SetRestrictedScenarios(true);
  }
#endif
  MindIRExporter mindir_exporter;
  return mindir_exporter.SaveProtoToFile(proto.get(), compile_cache_path);
}

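// Writes the dependency files hash to its cache file. The file is made writable for the update and
// restored to read-only afterwards.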
bool ExportDepFilesHash(const std::string &compile_cache_dep_files_hash) {
  std::string dep_files_hash_path = GetDepFilesHashPath();
  auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
    return false;
  }

  ChangeFileMode(realpath.value(), S_IWUSR);
  std::ofstream fout(realpath.value());
  if (!fout.is_open()) {
    MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
    return false;
  }
  fout << compile_cache_dep_files_hash;
  fout.close();
  ChangeFileMode(realpath.value(), S_IRUSR);
  return true;
}

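// Records the mapping from dataset phase to data queue name in the JSON cache file, merging with any
// existing entries. A corrupted read is retried once after a short sleep.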
bool ExportDataQueueName(const std::string &dataset_phase, const string &queue_name) {
  if (queue_name.empty()) {
    MS_LOG(INFO) << "Export data queue name in dataset phase: " << dataset_phase << ", queue name: " << queue_name;
    return true;
  }
  MS_LOG(INFO) << "Export data queue name in dataset phase: " << dataset_phase;
  auto &context = CompileCacheContext::GetInstance();
  context.set_has_cached_queue_name(true);
  const auto &filename = GetDataQueueNameCachePath(std::to_string(CompileCacheManager::data_queue_num_));
  MS_LOG(INFO) << "Export data queue name in file " << filename;
  nlohmann::json name_json;
  if (!Common::FileExists(filename)) {
    name_json[dataset_phase] = queue_name;
    return Common::SaveStringToFile(filename, name_json.dump());
  }
  std::ifstream json_fs(filename);
  if (!json_fs.good()) {
    return false;
  }
  try {
    json_fs >> name_json;
    json_fs.close();
  } catch (std::exception &e) {
    MS_LOG(INFO) << "Parse json file error: " << filename << ", sleep 500ms and retry.";
    json_fs.close();
    std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalMilliSeconds));
    std::ifstream retry_tmp(filename);
    if (!retry_tmp.good()) {
      MS_LOG(EXCEPTION) << "Open json file: " << filename << " error.";
    }
    retry_tmp >> name_json;
    retry_tmp.close();
  }
  name_json[dataset_phase] = queue_name;
  return Common::SaveStringToFile(filename, name_json.dump());
}

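// Restores the parallel communication groups from the cached group checkpoint file.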
bool CreateParallelGroupsByCkptFile(size_t index) {
  const std::string group_ckpt_save_path = GetGroupCkptSavePath(index);
  auto realpath = Common::CreatePrefixPath(group_ckpt_save_path, true);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path of file " << group_ckpt_save_path << " failed.";
    return false;
  }
  std::ifstream f(realpath.value());
  bool file_is_good = f.good();
  f.close();
  if (!file_is_good) {
    MS_LOG(ERROR) << "Open the group checkpoint file " << realpath.value() << " failed.";
    return false;
  }
  return parallel::CreateGroupsByCkptFile(group_ckpt_save_path);
}

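// Returns the data queue name of the graph, i.e. the "shared_name" attr of the first primitive that carries it.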
std::string GetDataQueueName(const FuncGraphPtr &fg) {
  auto cnodes = fg->GetOrderedCnodes();
  std::string queue_name;
  for (const auto &cnode : cnodes) {
    auto prim = GetValuePtr<Primitive>(cnode->input(0));
    if (prim != nullptr && prim->HasAttr("shared_name")) {
      StringImmPtr queue_name_ptr = std::dynamic_pointer_cast<StringImm>(prim->GetAttr("shared_name"));
      MS_EXCEPTION_IF_NULL(queue_name_ptr);
      queue_name = queue_name_ptr->value();
      break;
    }
  }
  return queue_name;
}
}  // namespace

size_t CompileCacheManager::data_queue_num_ = 0;
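// Looks up the cached data queue name for the given dataset phase. Returns an empty string when the
// compile cache is disabled, the name was already cached in this process, or the cache file is absent.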
std::string CompileCacheManager::GetCachedDataQueueName(const std::string &dataset_phase) {
  std::string queue_name;
  if (!CompileCacheEnable()) {
    return queue_name;
  }
  data_queue_num_++;
  auto &config_mng = ConfigManager::GetInstance();
  if (config_mng.dataset_phase().empty()) {
    config_mng.set_dataset_phase(dataset_phase);
  }
  // If the queue name has already been cached, we should not fetch it from the cache file again in the same process.
  auto &context = CompileCacheContext::GetInstance();
  if (context.has_cached_queue_name()) {
    return queue_name;
  }
  const auto &filename = GetDataQueueNameCachePath(std::to_string(CompileCacheManager::data_queue_num_));
  MS_LOG(INFO) << "Get data queue name from file " << filename;
  std::ifstream json_fs(filename);
  if (!json_fs.good()) {
    return queue_name;
  }
  nlohmann::json name_json;
  try {
    json_fs >> name_json;
    json_fs.close();
  } catch (std::exception &e) {
    MS_LOG(INFO) << "Parse json file error: " << filename << ", sleep 500ms and retry.";
    json_fs.close();
    std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalMilliSeconds));
    std::ifstream retry_tmp(filename);
    if (!retry_tmp.good()) {
      MS_LOG(EXCEPTION) << "Open json file: " << filename << " error.";
    }
    retry_tmp >> name_json;
    retry_tmp.close();
  }
  queue_name = name_json[dataset_phase];
  return queue_name;
}

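// Exports the compiled frontend graph, its data queue name and (for the first graph) the dependency
// files hash to the compile cache directory.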
void CompileCacheManager::CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg) {
  if (fg == nullptr) {
    MS_LOG(ERROR) << "The func_graph to be cached is null.";
    return;
  }

  const auto &queue_name = GetDataQueueName(fg);
  auto dataset_phase = ConfigManager::GetInstance().dataset_phase();
  if (!ExportDataQueueName(dataset_phase, queue_name)) {
    MS_LOG(ERROR) << "Failed to cache data queue name: " << queue_name;
    return;
  }

  SetCompileCacheDir(GetCompileCacheDir());

  if (!ExportFuncGraphToMindIR(fg, layout_fg, compile_cache_id_)) {
    MS_LOG(ERROR) << "Failed to cache graph: " << fg->ToString();
    return;
  }
  if (compile_cache_id_ == 0 && !ExportDepFilesHash(compile_cache_dep_files_hash_)) {
    MS_LOG(ERROR) << "Failed to cache the dependency files hash.";
  }
}

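// Computes the dependency files hash for this compilation and publishes it to the compile cache context.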
void CompileCacheManager::InitCompileCacheHash(const py::list &compile_cache_dep_files) {
  compile_cache_dep_files_hash_ = GetCompileDepFilesHash(compile_cache_dep_files);
  auto &context = CompileCacheContext::GetInstance();
  context.SetCompileCacheDepFilesHash(compile_cache_dep_files_hash_);
}

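// The cache can be loaded only if the current dependency files hash matches the cached one and the
// cached MindIR file for this graph exists.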
bool CompileCacheManager::CanLoadCache() {
  if (compile_cache_dep_files_hash_.empty()) {
    MS_LOG(ERROR) << "Get current dependency files hash failed.";
    return false;
  }
  std::string dep_files_hash_path = GetDepFilesHashPath();
  auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
    return false;
  }
  std::fstream input(realpath.value(), std::ios::in | std::ios::binary);
  if (!input) {
    MS_LOG(WARNING) << "Open the hash file " << realpath.value() << " failed. The file may not exist."
                    << ErrnoToString(errno);
    return false;
  }
  std::string checkpoint_hash;
  input >> checkpoint_hash;
  if (checkpoint_hash.empty()) {
    MS_LOG(ERROR) << "Get the compilation dependency files hash from " << realpath.value() << " failed.";
    return false;
  }
  if (checkpoint_hash != compile_cache_dep_files_hash_) {
    MS_LOG(WARNING) << "The compilation dependency files have changed.";
    return false;
  }
  auto compile_cache_path = GetCompileCachePath(compile_cache_id_);
  struct stat buffer;
  if (stat(compile_cache_path.c_str(), &buffer) != 0) {
    MS_LOG(WARNING) << "Failed to find cache file, execute all the compilation actions.";
    return false;
  }
  return true;
}

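// Loads the cached frontend graph for this compilation: restores the parallel groups when running in
// (semi) auto parallel mode, deserializes the MindIR cache file, and refreshes the "shared_name" attr
// with the current data queue name.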
FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr &manager, const py::dict &weights,
                                                     const std::string &queue_name) {
  // Determine whether to load the parallel information.
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  bool has_parallel_info = false;
  if ((parallel_mode == parallel::kAutoParallel) || (parallel_mode == parallel::kSemiAutoParallel)) {
    if (!CreateParallelGroupsByCkptFile(compile_cache_id_)) {
      MS_LOG(WARNING) << "Failed to create the parallel groups info. Execute all the compilation actions.";
      return nullptr;
    }
    has_parallel_info = true;
  }
  // Load the compilation cache file.
  auto pair = LoadFuncGraphFromMindIR(weights, has_parallel_info, compile_cache_id_);
  if (pair.first == nullptr) {
    MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
    return nullptr;
  }
  auto fg = pair.first;
  layout_map_ = pair.second;

  MS_LOG(WARNING) << "Use the compilation cache and execute the backend actions only. Be aware of correctness risks.";
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    MS_EXCEPTION_IF_NULL(manager);
    manager->AddFuncGraph(fg);
    fg->set_manager(manager);
  }
  // The value of the attr "shared_name" changes every run, so refresh it with the current queue name.
  auto cnodes = fg->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    auto prim = GetValuePtr<Primitive>(cnode->input(0));
    if (prim != nullptr && prim->HasAttr("shared_name")) {
      prim->set_attr("shared_name", MakeValue(queue_name));
      break;
    }
  }
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpIR("cache_loaded_graph_" + std::to_string(compile_cache_id_) + ".ir", fg);
  }
#endif
  return fg;
}

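// In (semi) auto parallel mode, tells the parallel context where to save the group checkpoint so that
// the groups can be restored when the cache is loaded.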
void CompileCacheManager::InitParallelGroupCkptSaveFile() {
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if ((parallel_mode == parallel::kAutoParallel) || (parallel_mode == parallel::kSemiAutoParallel)) {
    parallel::ParallelContext::GetInstance()->set_group_ckpt_save_file(GetGroupCkptSavePath(compile_cache_id_));
  }
}
}  // namespace pipeline
}  // namespace mindspore