1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/ps/compile_cache_manager.h"
18 #include <vector>
19 #include <algorithm>
20 #include <map>
21 #include <utility>
22 #include <fstream>
23 #include "pipeline/jit/ps/parse/data_converter.h"
24 #include "include/common/utils/parallel_context.h"
25 #include "include/common/debug/common.h"
26 #include "include/common/debug/anf_ir_dump.h"
27 #include "include/common/debug/dump_proto.h"
28 #include "utils/system/sha256.h"
29 #include "include/common/utils/utils.h"
30 #include "frontend/parallel/step_parallel.h"
31 #include "frontend/parallel/tensor_layout/shared_parameter.h"
32 #include "mindspore/core/utils/file_utils.h"
33
34 #if defined(__linux__) && defined(WITH_BACKEND)
35 #include "include/backend/distributed/cluster/cluster_context.h"
36 #include "include/backend/distributed/ps/ps_context.h"
37 #endif
38 #include "include/common/utils/compile_cache_context.h"
39 #include "include/common/utils/config_manager.h"
40
41 namespace mindspore {
42 #ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
BuildLayout(const FuncGraphPtr & func_graph,mind_ir::ModelProto * model)43 void BuildLayout(const FuncGraphPtr &func_graph, mind_ir::ModelProto *model) {
44 MS_EXCEPTION_IF_NULL(func_graph);
45 MS_EXCEPTION_IF_NULL(model);
46 std::vector<AnfNodePtr> graph_params = func_graph->parameters();
47 mind_ir::ParallelProto *parallel_proto = model->mutable_parallel();
48 for (auto para : graph_params) {
49 std::string name = std::static_pointer_cast<Parameter>(para)->name();
50 auto tensor_layout = para->user_data<parallel::TensorLayout>();
51 if (tensor_layout == nullptr) {
52 MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
53 } else {
54 mind_ir::LayoutProto *layoutProto = parallel_proto->add_layout();
55
56 // Get all the information for layput
57 auto device_arrangement = tensor_layout->device_arrangement().array();
58 auto tensor_map = tensor_layout->tensor_map().array();
59 auto slice_shape = tensor_layout->slice_shape().array();
60 int64_t field_size = tensor_layout->get_field_size();
61 bool uniform_split = tensor_layout->uniform_split();
62 std::string opt_shard_group = tensor_layout->opt_shard_group();
63 if (!opt_shard_group.empty()) {
64 slice_shape = tensor_layout->opt_shard_slice_shape();
65 }
66 // Save all information to Layout Proto
67 layoutProto->set_name(name);
68 for (auto device_arrangement_element : device_arrangement) {
69 layoutProto->add_device_arrangement_int(device_arrangement_element);
70 }
71 for (auto tensor_map_element : tensor_map) {
72 layoutProto->add_tensor_map_int(tensor_map_element);
73 }
74 for (auto slice_shape_element : slice_shape) {
75 layoutProto->add_slice_shape_int(slice_shape_element);
76 }
77 layoutProto->set_field_size(field_size);
78 layoutProto->set_uniform_split(uniform_split);
79 layoutProto->set_opt_shard_group(opt_shard_group);
80 auto shared_param = para->user_data<parallel::SharedParameter>();
81 if (shared_param) {
82 layoutProto->set_pipeline_shared(shared_param->pipeline_shared());
83 layoutProto->set_is_send(shared_param->is_send());
84 layoutProto->set_peer_rank(shared_param->peer_rank());
85 layoutProto->set_sr_tag(shared_param->sr_tag());
86 }
87 }
88 }
89 }
90 #endif
91 namespace pipeline {
92 namespace {
GetCompileCacheDir()93 std::string GetCompileCacheDir() {
94 static const std::string user_defined_path = Common::GetUserDefineCachePath();
95 static const uint32_t rank_id = IsStandAlone() ? 0 : GetRank();
96 static const std::string compile_cache_dir = user_defined_path + "rank_" + std::to_string(rank_id);
97 return compile_cache_dir;
98 }
99
GetGraphCacheDir()100 std::string GetGraphCacheDir() { return GetCompileCacheDir() + "/" + kGraphCacheSubDir; }
101
/// Returns the role of this node in an initialized distributed cluster, or an empty string otherwise
/// (including builds without Linux backend support). The role is used as a prefix of cache file names.
std::string GetRole() {
#if defined(__linux__) && defined(WITH_BACKEND)
  // Fetch the singleton once and validate it before any dereference.
  const auto &cluster_ctx = distributed::cluster::ClusterContext::instance();
  MS_EXCEPTION_IF_NULL(cluster_ctx);
  if (cluster_ctx->initialized()) {
    // An initialized cluster is expected to have a node attached.
    auto node = cluster_ctx->node();
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(INFO) << "Cluster is initialized. This node role is " << cluster_ctx->node_role();
    return cluster_ctx->node_role();
  }
#endif
  return "";
}
115
GetCompileCachePath(size_t idx)116 std::string GetCompileCachePath(size_t idx) {
117 return GetGraphCacheDir() + "/" + GetRole() + kCompileCacheFileName + "_" + std::to_string(idx) + kMindIrSuffix;
118 }
119
GetBackendCompileCachePathWithoutExtension(size_t idx)120 std::string GetBackendCompileCachePathWithoutExtension(size_t idx) {
121 return GetGraphCacheDir() + "/" + GetRole() + kBackendCompileCacheFileName + "_" + std::to_string(idx);
122 }
123
GetDepFilesHashPath()124 std::string GetDepFilesHashPath() {
125 static const std::string dep_files_hash_path = GetGraphCacheDir() + "/" + GetRole() + kDepFilesHashPath;
126 return dep_files_hash_path;
127 }
128
GetGroupCkptSavePath(size_t index)129 std::string GetGroupCkptSavePath(size_t index) {
130 auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
131 if (!group_info_save_path.empty()) {
132 return group_info_save_path;
133 }
134 return GetGraphCacheDir() + "/group_" + std::to_string(index) + ".ckpt";
135 }
136
GetDataQueueNameCachePath(const std::string & data_queue_num)137 std::string GetDataQueueNameCachePath(const std::string &data_queue_num) {
138 std::string queue_name_cache_path =
139 GetGraphCacheDir() + "/" + GetRole() + "_" + data_queue_num + kDataQueueNameCacheFileName;
140 return queue_name_cache_path;
141 }
142
GetCompileDepFilesHash(const py::list & dep_files)143 std::string GetCompileDepFilesHash(const py::list &dep_files) {
144 MS_LOG(DEBUG) << "Dependency files size: " << dep_files.size();
145 std::vector<std::string> dep_files_path;
146 for (auto dep_file : dep_files) {
147 auto file_path = py::cast<std::string>(dep_file);
148 MS_LOG(DEBUG) << "Dependency file path: " << file_path;
149 (void)dep_files_path.emplace_back(file_path);
150 }
151 std::sort(dep_files_path.begin(), dep_files_path.end());
152 std::string files_hash;
153 for (const auto &path : dep_files_path) {
154 std::string file_hash = system::sha256::GetHashFromFile(path);
155 files_hash += file_hash;
156 }
157 std::string files_hash_hash = system::sha256::GetHashFromString(files_hash);
158 return files_hash_hash;
159 }
160
GenerateWeightsValueMap(const py::dict & weights)161 std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
162 std::map<string, ValuePtr> ret{};
163 for (auto weight = weights.begin(); weight != weights.end(); ++weight) {
164 auto weight_name = py::cast<std::string>(weight->first);
165 auto weight_value = parse::data_converter::PyDataToValue(py::cast<py::object>(weight->second));
166 ret[weight_name] = weight_value;
167 }
168 return ret;
169 }
170
LoadFuncGraphFromMindIR(const py::dict & weights,bool has_parallel_info,size_t idx)171 std::pair<FuncGraphPtr, LayoutMap> LoadFuncGraphFromMindIR(const py::dict &weights, bool has_parallel_info,
172 size_t idx) {
173 LayoutMap layout_map;
174 std::string compile_cache_path = GetCompileCachePath(idx);
175 auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
176 if (!realpath.has_value()) {
177 MS_LOG(ERROR) << "Get real path of file " << compile_cache_path << " failed.";
178 return std::make_pair(nullptr, layout_map);
179 }
180 struct stat buffer;
181 if (stat(realpath.value().c_str(), &buffer) != 0) {
182 MS_LOG(WARNING) << "Open the compilation cache file " << realpath.value() << " failed.";
183 return std::make_pair(nullptr, layout_map);
184 }
185 auto ms_context = MsContext::GetInstance();
186 MS_EXCEPTION_IF_NULL(ms_context);
187 ms_context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);
188 MindIRLoader mindir_loader;
189 mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
190 mindir_loader.set_has_parallel_info(has_parallel_info);
191 mindspore::HashMap<std::string, AnfNodePtr> name_to_node;
192 auto fg = mindir_loader.LoadMindIR(realpath.value(), &name_to_node);
193 auto &context = CompileCacheContext::GetInstance();
194 context.SetFrontNameToFrontNode(name_to_node);
195 context.SetFrontGraph(fg);
196 context.InsertBackendGraphCachePath(fg, GetBackendCompileCachePathWithoutExtension(idx));
197
198 if (ms_context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
199 MS_LOG(INFO) << "Cell reuse(@lazy_inline) actually takes effect.";
200 }
201 #if defined(__linux__) && defined(WITH_BACKEND)
202 // compile cache does not support host collective or graph kernel.
203 if (common::UseHostCollective() || ms_context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL)) {
204 context.SetRestrictedScenarios(true);
205 }
206 #endif
207 return std::make_pair(fg, mindir_loader.layout_map());
208 }
209
ExportFuncGraphToMindIR(const FuncGraphPtr & fg,const FuncGraphPtr & layout_fg,size_t idx)210 bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, size_t idx) {
211 std::string compile_cache_path = GetCompileCachePath(idx);
212 auto proto = GenBinaryProto(fg);
213 if (proto == nullptr) {
214 MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
215 return false;
216 }
217 #ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP
218 if (layout_fg) {
219 BuildLayout(layout_fg, proto.get());
220 }
221 #endif
222 auto &context = CompileCacheContext::GetInstance();
223 context.SetFrontGraph(fg);
224 context.InsertBackendGraphCachePath(fg, GetBackendCompileCachePathWithoutExtension(idx));
225 #if defined(__linux__) && defined(WITH_BACKEND)
226 // compile cache does not support host collective or graph kernel.
227 auto ms_context = MsContext::GetInstance();
228 MS_EXCEPTION_IF_NULL(ms_context);
229 if (common::UseHostCollective() || ms_context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL)) {
230 context.SetRestrictedScenarios(true);
231 }
232 #endif
233 MindIRExporter mindir_exporter;
234 return mindir_exporter.SaveProtoToFile(proto.get(), compile_cache_path);
235 }
236
ExportDepFilesHash(const std::string & compile_cache_dep_files_hash)237 bool ExportDepFilesHash(const std::string &compile_cache_dep_files_hash) {
238 std::string dep_files_hash_path = GetDepFilesHashPath();
239 auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
240 if (!realpath.has_value()) {
241 MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
242 return false;
243 }
244
245 ChangeFileMode(realpath.value(), S_IWUSR);
246 std::ofstream fout(realpath.value());
247 if (!fout.is_open()) {
248 MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
249 return false;
250 }
251 fout << compile_cache_dep_files_hash;
252 fout.close();
253 ChangeFileMode(realpath.value(), S_IRUSR);
254 return true;
255 }
256
ExportDataQueueName(const std::string & dataset_phase,const string & queue_name)257 bool ExportDataQueueName(const std::string &dataset_phase, const string &queue_name) {
258 if (queue_name.empty()) {
259 MS_LOG(INFO) << "Export data queue name in dataset phase: " << dataset_phase << ", queue name: " << queue_name;
260 return true;
261 }
262 MS_LOG(INFO) << "Export data queue name in dataset phase: " << dataset_phase;
263 auto &context = CompileCacheContext::GetInstance();
264 context.set_has_cached_queue_name(true);
265 const auto &filename = GetDataQueueNameCachePath(std::to_string(CompileCacheManager::data_queue_num_));
266 MS_LOG(INFO) << "Export data queue name in file " << filename;
267 nlohmann::json name_json;
268 if (!Common::FileExists(filename)) {
269 name_json[dataset_phase] = queue_name;
270 return Common::SaveStringToFile(filename, name_json.dump());
271 }
272 std::ifstream json_fs(filename);
273 if (!json_fs.good()) {
274 return false;
275 }
276 try {
277 json_fs >> name_json;
278 json_fs.close();
279 } catch (std::exception &e) {
280 MS_LOG(INFO) << "Parse json file error: " << filename << ", sleep 500ms and retry again.";
281 json_fs.close();
282 std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalMilliSeconds));
283 std::ifstream retry_tmp(filename);
284 if (!retry_tmp.good()) {
285 MS_LOG(EXCEPTION) << "Open json file: " << filename << " error.";
286 }
287 retry_tmp >> name_json;
288 retry_tmp.close();
289 }
290 name_json[dataset_phase] = queue_name;
291 return Common::SaveStringToFile(filename, name_json.dump());
292 }
293
CreateParallelGroupsByCkptFile(size_t index)294 bool CreateParallelGroupsByCkptFile(size_t index) {
295 const std::string group_ckpt_save_path = GetGroupCkptSavePath(index);
296 auto realpath = Common::CreatePrefixPath(group_ckpt_save_path, true);
297 if (!realpath.has_value()) {
298 MS_LOG(ERROR) << "Get real path of file " << group_ckpt_save_path << " failed.";
299 return false;
300 }
301 std::ifstream f(realpath.value());
302 bool file_is_good = f.good();
303 f.close();
304 if (!file_is_good) {
305 MS_LOG(ERROR) << "Open the group checkpoint file " << realpath.value() << " failed.";
306 return false;
307 }
308 return parallel::CreateGroupsByCkptFile(group_ckpt_save_path);
309 }
310
GetDataQueueName(const FuncGraphPtr & fg)311 std::string GetDataQueueName(const FuncGraphPtr &fg) {
312 auto cnodes = fg->GetOrderedCnodes();
313 std::string queue_name;
314 for (const auto &cnode : cnodes) {
315 auto prim = GetValuePtr<Primitive>(cnode->input(0));
316 if (prim != nullptr && prim->HasAttr("shared_name")) {
317 StringImmPtr queue_name_ptr = std::dynamic_pointer_cast<StringImm>(prim->GetAttr("shared_name"));
318 queue_name = queue_name_ptr->value();
319 break;
320 }
321 }
322 return queue_name;
323 }
324 } // namespace
325
326 size_t CompileCacheManager::data_queue_num_ = 0;
GetCachedDataQueueName(const std::string & dataset_phase)327 std::string CompileCacheManager::GetCachedDataQueueName(const std::string &dataset_phase) {
328 std::string queue_name;
329 if (!CompileCacheEnable()) {
330 return queue_name;
331 }
332 data_queue_num_++;
333 auto &config_mng = ConfigManager::GetInstance();
334 if (config_mng.dataset_phase().empty()) {
335 config_mng.set_dataset_phase(dataset_phase);
336 }
337 // if queue name has cached, we should not get it again from cache file in the same process.
338 auto &context = CompileCacheContext::GetInstance();
339 if (context.has_cached_queue_name()) {
340 return queue_name;
341 }
342 const auto &filename = GetDataQueueNameCachePath(std::to_string(CompileCacheManager::data_queue_num_));
343 MS_LOG(INFO) << "Get data queue name from file " << filename;
344 std::ifstream json_fs(filename);
345 if (!json_fs.good()) {
346 return queue_name;
347 }
348 nlohmann::json name_json;
349 try {
350 json_fs >> name_json;
351 json_fs.close();
352 } catch (std::exception &e) {
353 MS_LOG(INFO) << "Parse json file error: " << filename << ", sleep 500ms and retry again.";
354 json_fs.close();
355 std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalMilliSeconds));
356 std::ifstream retry_tmp(filename);
357 if (!retry_tmp.good()) {
358 MS_LOG(EXCEPTION) << "Open json file: " << filename << " error.";
359 }
360 retry_tmp >> name_json;
361 retry_tmp.close();
362 }
363 queue_name = name_json[dataset_phase];
364 return queue_name;
365 }
366
CacheFuncGraph(const FuncGraphPtr & fg,const FuncGraphPtr & layout_fg)367 void CompileCacheManager::CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg) {
368 if (fg == nullptr) {
369 MS_LOG(ERROR) << "The func_graph to be cached is null.";
370 return;
371 }
372
373 const auto &queue_name = GetDataQueueName(fg);
374 auto dataset_phase = ConfigManager::GetInstance().dataset_phase();
375 if (!ExportDataQueueName(dataset_phase, queue_name)) {
376 MS_LOG(ERROR) << "Failed to cache data queue name: " << queue_name;
377 return;
378 }
379
380 SetCompileCacheDir(GetCompileCacheDir());
381
382 if (!ExportFuncGraphToMindIR(fg, layout_fg, compile_cache_id_)) {
383 MS_LOG(ERROR) << "Failed to cache graph: " << fg->ToString();
384 return;
385 }
386 if (compile_cache_id_ == 0 && !ExportDepFilesHash(compile_cache_dep_files_hash_)) {
387 MS_LOG(ERROR) << "Failed to cache the dependency files hash";
388 }
389 }
390
InitCompileCacheHash(const py::list & compile_cache_dep_files)391 void CompileCacheManager::InitCompileCacheHash(const py::list &compile_cache_dep_files) {
392 compile_cache_dep_files_hash_ = GetCompileDepFilesHash(compile_cache_dep_files);
393 auto &context = CompileCacheContext::GetInstance();
394 context.SetCompileCacheDepFilesHash(compile_cache_dep_files_hash_);
395 }
396
CanLoadCache()397 bool CompileCacheManager::CanLoadCache() {
398 if (compile_cache_dep_files_hash_.empty()) {
399 MS_LOG(ERROR) << "Get current dependency files hash failed.";
400 return false;
401 }
402 std::string dep_files_hash_path = GetDepFilesHashPath();
403 auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
404 if (!realpath.has_value()) {
405 MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
406 return false;
407 }
408 std::fstream input(realpath.value(), std::ios::in | std::ios::binary);
409 if (!input) {
410 MS_LOG(WARNING) << "Open the hash file " << realpath.value() << " failed. The file may not exist."
411 << ErrnoToString(errno);
412 return false;
413 }
414 std::string checkpoint_hash;
415 input >> checkpoint_hash;
416 if (checkpoint_hash.empty()) {
417 MS_LOG(ERROR) << "Get the compilation dependency files hash from " << realpath.value() << " failed.";
418 return false;
419 }
420 if (checkpoint_hash != compile_cache_dep_files_hash_) {
421 MS_LOG(WARNING) << "The compilation dependency files are changed.";
422 return false;
423 }
424 auto compile_cache_path = GetCompileCachePath(compile_cache_id_);
425 struct stat buffer;
426 if (stat(compile_cache_path.c_str(), &buffer) != 0) {
427 MS_LOG(WARNING) << "Failed to find cache file, execute all the compilation actions.";
428 return false;
429 }
430 return true;
431 }
432
GetCachedFuncGraph(const FuncGraphManagerPtr & manager,const py::dict & weights,const std::string & queue_name)433 FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr &manager, const py::dict &weights,
434 const std::string &queue_name) {
435 // Determine whether to load parallel information.
436 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
437 bool has_parallel_info = false;
438 if ((parallel_mode == parallel::kAutoParallel) || (parallel_mode == parallel::kSemiAutoParallel)) {
439 if (!CreateParallelGroupsByCkptFile(compile_cache_id_)) {
440 MS_LOG(WARNING) << "Failed to create the parallel groups info. Execute all the compilation actions.";
441 return nullptr;
442 }
443 has_parallel_info = true;
444 }
445 // Load the compilation cache file.
446 auto pair = LoadFuncGraphFromMindIR(weights, has_parallel_info, compile_cache_id_);
447 if (pair.first == nullptr) {
448 MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
449 return nullptr;
450 }
451 auto fg = pair.first;
452 layout_map_ = pair.second;
453
454 MS_LOG(WARNING) << "Use the compilation cache and execute the backend actions only. Be aware of correctness risks.";
455 FuncGraphManagerPtr mng = fg->manager();
456 if (mng == nullptr) {
457 MS_EXCEPTION_IF_NULL(manager);
458 manager->AddFuncGraph(fg);
459 fg->set_manager(manager);
460 }
461 // The value of attr "shared_name" will changed every time.
462 auto cnodes = fg->GetOrderedCnodes();
463 for (const auto &cnode : cnodes) {
464 auto prim = GetValuePtr<Primitive>(cnode->input(0));
465 if (prim != nullptr && prim->HasAttr("shared_name")) {
466 prim->set_attr("shared_name", MakeValue(queue_name));
467 break;
468 }
469 }
470 #ifdef ENABLE_DUMP_IR
471 auto context = MsContext::GetInstance();
472 MS_EXCEPTION_IF_NULL(context);
473 if (context->CanDump(kIntroductory)) {
474 DumpIR("cache_loaded_graph_" + std::to_string(compile_cache_id_) + ".ir", fg);
475 }
476 #endif
477 return fg;
478 }
479
InitParallelGroupCkptSaveFile()480 void CompileCacheManager::InitParallelGroupCkptSaveFile() {
481 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
482 if ((parallel_mode == parallel::kAutoParallel) || (parallel_mode == parallel::kSemiAutoParallel)) {
483 parallel::ParallelContext::GetInstance()->set_group_ckpt_save_file(GetGroupCkptSavePath(compile_cache_id_));
484 }
485 }
486 } // namespace pipeline
487 } // namespace mindspore
488