• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/distributed/recovery/recovery_context.h"
18 
19 #include <dirent.h>
20 #include <algorithm>
21 #include <utility>
22 #include <map>
23 
24 #include "nlohmann/json.hpp"
25 #include "include/backend/distributed/ps/ps_context.h"
26 #include "include/backend/distributed/ps/constants.h"
27 #include "utils/file_utils.h"
28 #include "include/backend/distributed/constants.h"
29 #include "distributed/persistent/storage/file_io_utils.h"
30 #include "distributed/persistent/storage/json_utils.h"
31 #include "include/backend/distributed/cluster/topology/common.h"
32 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
33 #include "include/backend/distributed/cluster/cluster_context.h"
34 #include "include/backend/distributed/cluster/topology/compute_graph_node.h"
35 #endif
36 #include "runtime/hardware/device_context_manager.h"
37 #include "utils/convert_utils_base.h"
38 #include "utils/ms_context.h"
39 
40 namespace mindspore {
41 namespace distributed {
42 namespace recovery {
// Extension used to recognise checkpoint files on disk.
constexpr char kCkptSuffix[] = ".ckpt";
// Key under which the checkpoint directory is stored in the persistent json.
constexpr char kCkptPath[] = "ckpt_path";
// Extension appended to the store file names written into the config file.
constexpr char kJsonSuffix[] = ".json";
// Config file name, appended to the recovery path.
constexpr char kConfigJson[] = "/config.json";

// Number of integers each rank contributes to the checkpoint info buffer (epoch, step).
constexpr uint32_t kSendBufferLen = 2;

// Metadata key prefixes; the rank index is appended to form the full key.
constexpr char kCkptEpochInfoPrefix[] = "ckpt_epoch_rank_";
constexpr char kCkptStepInfoPrefix[] = "ckpt_step_rank_";
52 
53 namespace {
ParseCkptEpochStep(const std::string & checkpoint)54 std::pair<int, int> ParseCkptEpochStep(const std::string &checkpoint) {
55   size_t suffix_pos = checkpoint.rfind('.');
56   if (suffix_pos == std::string::npos || checkpoint.substr(suffix_pos) != kCkptSuffix) {
57     MS_LOG(WARNING) << "The file : " << checkpoint << "is not a checkpoint";
58     return {};
59   }
60 
61   size_t epoch_begin_pos = checkpoint.rfind('-');
62   size_t step_begin_pos = checkpoint.rfind('_');
63   if (epoch_begin_pos == std::string::npos || step_begin_pos == std::string::npos) {
64     MS_LOG(EXCEPTION) << "The checkpoint file name is not valid: " << checkpoint;
65   }
66 
67   return std::make_pair(std::stoi(checkpoint.substr(epoch_begin_pos + 1, (step_begin_pos - epoch_begin_pos) - 1)),
68                         std::stoi(checkpoint.substr(step_begin_pos + 1, (suffix_pos - step_begin_pos) - 1)));
69 }
70 
// Delete every entry of files_list located inside directory.
// The result of remove() is deliberately ignored, so missing files are skipped silently.
void RemoveAllCkptFiles(const std::string &directory, const std::vector<std::string> &files_list) {
  for (const auto &file_name : files_list) {
    const std::string full_path = directory + "/" + file_name;
    (void)remove(full_path.c_str());
  }
}
78 }  // namespace
79 
IsEnableRecovery()80 bool IsEnableRecovery() { return common::GetEnv(kEnvEnableRecovery) == std::string("1"); }
81 
RecoveryPath()82 std::string RecoveryPath() { return common::GetEnv(kEnvRecoveryPath); }
83 
Initialize()84 void RecoveryContext::Initialize() {
85   if (initialized_) {
86     return;
87   }
88 
89   // 1. Read environment variable.
90   enable_recovery_ = IsEnableRecovery();
91   if (!enable_recovery_) {
92     return;
93   }
94 
95   auto context_ptr = MsContext::GetInstance();
96   MS_EXCEPTION_IF_NULL(context_ptr);
97   context_ptr->set_param<bool>(MS_CTX_ENABLE_RECOVERY, true);
98 
99   recovery_path_ = RecoveryPath();
100   if (recovery_path_.empty()) {
101     MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
102   }
103 
104   auto env_recovery_interval = common::GetEnv(kEnvRecoveryInterval);
105   if (!env_recovery_interval.empty()) {
106     recovery_interval_ = std::stoi(env_recovery_interval);
107   }
108 
109   node_role_ = common::GetEnv(distributed::kEnvRole);
110   if (distributed::kValidRoleName.count(node_role_) == 0) {
111     MS_LOG(EXCEPTION) << "Role name '" << node_role_ << "' is invalid. ";
112   }
113 
114   // 2. Get real recovery path and create config file.
115   if (!storage::FileIOUtils::IsFileOrDirExist(recovery_path_)) {
116     storage::FileIOUtils::CreateDirRecursive(recovery_path_);
117   }
118 
119   auto ret = FileUtils::GetRealPath(recovery_path_.c_str());
120   if (!ret.has_value()) {
121     MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_;
122   }
123   recovery_path_ = ret.value();
124 
125   std::string config_file_path = recovery_path_ + kConfigJson;
126   if (!storage::FileIOUtils::IsFileOrDirExist(config_file_path)) {
127     CreateConfigFile(config_file_path);
128   }
129 
130   // 3. Set config content to PSContext.
131   ps::PSContext::instance()->set_config_file_path(config_file_path);
132   ps::PSContext::instance()->set_node_id(common::GetEnv(distributed::cluster::topology::kEnvNodeId));
133 
134   initialized_ = true;
135 }
136 
ObtainGlobalLatestCkptInfo()137 void RecoveryContext::ObtainGlobalLatestCkptInfo() {
138   // 1. Obtain the step corresponding to the local latest checkpoint.
139   ObtainLocalLatestCkptInfo();
140 
141   // For standalone training.
142   if (global_rank_size_ == 0) {
143     return;
144   }
145 
146   // 2. AllGather the latest checkpoint info of all nodes.
147   device::DeviceContextKey host_key = {"CPU", 0};
148   device::DeviceContext *host_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
149   MS_EXCEPTION_IF_NULL(host_context);
150   MS_EXCEPTION_IF_NULL(host_context->device_res_manager_);
151   device::CollectiveCommunicationLib *host_comm_lib_instance = host_context->device_res_manager_->collective_comm_lib();
152   MS_EXCEPTION_IF_NULL(host_comm_lib_instance);
153 
154   if (global_rank_id_ >= global_rank_size_) {
155     MS_LOG(EXCEPTION) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
156                       << global_rank_size_;
157   }
158 
159   const std::size_t kRecvBufferLen = kSendBufferLen * global_rank_size_;
160 
161   std::vector<int> recv_buffer(kRecvBufferLen, 0);
162 
163 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
164   // Synchronize the checkpoint information between all the other nodes to ensure the accuracy of training.
165   auto node = cluster::ClusterContext::instance()->node();
166   MS_EXCEPTION_IF_NULL(node);
167   auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(node);
168   MS_EXCEPTION_IF_NULL(cgn);
169 
170   // Start the ckpt file info exchange process.
171   std::map<std::string, std::string> results;
172   const std::string biz = "sync_ckpt";
173 
174   std::vector<std::string> names_prefix;
175   (void)names_prefix.emplace_back(kCkptEpochInfoPrefix);
176   (void)names_prefix.emplace_back(kCkptStepInfoPrefix);
177 
178   std::vector<std::string> values;
179   values.push_back(std::to_string(latest_ckpt_epoch_));
180   values.push_back(std::to_string(latest_ckpt_step_));
181 
182   if (cgn->ExchangeMetadata(biz, global_rank_size_, names_prefix, values, &results, INT_MAX)) {
183     for (uint32_t i = 0; i < global_rank_size_; ++i) {
184       auto epoch_key = kCkptEpochInfoPrefix + std::to_string(i);
185       auto step_key = kCkptStepInfoPrefix + std::to_string(i);
186       auto ckpt_epoch = results[epoch_key];
187       auto ckpt_step = results[step_key];
188       if (ckpt_epoch.length() > 0 && ckpt_step.length() > 0) {
189         recv_buffer[kSendBufferLen * i] = std::stoi(ckpt_epoch);
190         recv_buffer[kSendBufferLen * i + 1] = std::stoi(ckpt_step);
191         MS_LOG(INFO) << "The latest checkpoint for rank " << i << "is that epoch: " << ckpt_epoch
192                      << ", step: " << ckpt_step;
193       }
194     }
195     MS_LOG(INFO) << "The checkpoint information of all the ranks have been synchronized.";
196   }
197 #endif
198 
199   // 3. Check whether save checkpoint successfully on every workers.
200   uint32_t save_ckpt_success_num = 0;
201   uint32_t save_ckpt_failed_num = 0;
202   for (uint32_t i = 0; i < kRecvBufferLen; i += kSendBufferLen) {
203     if (recv_buffer[i] < 0) {
204       save_ckpt_failed_num++;
205     } else {
206       save_ckpt_success_num++;
207     }
208   }
209 
210   if (save_ckpt_success_num > 0 && save_ckpt_failed_num > 0) {
211     RemoveAllCkptFiles(GetCkptPath(), ckpt_files_);
212     MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over.";
213   }
214   if (save_ckpt_success_num == 0 && save_ckpt_failed_num == global_rank_size_) {
215     return;
216   }
217 
218   // 4. Parse latest epoch and step info.
219   ParseLatestCkptInfo(recv_buffer);
220 
221   // 5. Remove useless ckpt
222   for (int i = SizeToInt(ckpt_files_.size()) - 1; i >= 0; i--) {
223     const auto &last_ckpt_name = ckpt_files_[IntToSize(i)];
224     const auto &last_ckpt_file = GetCkptPath() + "/" + last_ckpt_name;
225     if (last_ckpt_file != latest_ckpt_file_) {
226       (void)remove(last_ckpt_file.c_str());
227     } else {
228       break;
229     }
230   }
231 }
232 
ObtainLocalLatestCkptInfo()233 void RecoveryContext::ObtainLocalLatestCkptInfo() {
234   std::string ckpt_save_dir = GetCkptPath();
235   if (ckpt_save_dir.empty()) {
236     MS_LOG(INFO) << "The ckpt file path is empty";
237     return;
238   }
239 
240   DIR *dir = opendir(ckpt_save_dir.c_str());
241   if (dir == nullptr) {
242     MS_LOG(EXCEPTION) << "The file path [" << ckpt_save_dir << "] is not exist";
243   }
244 
245   if (!ckpt_files_.empty()) {
246     ckpt_files_.clear();
247   }
248 
249   struct dirent *entry;
250   while ((entry = readdir(dir)) != nullptr) {
251     std::string file_name = entry->d_name;
252     size_t suffix_pos = file_name.rfind('.');
253     if (suffix_pos == std::string::npos || file_name.substr(suffix_pos) != kCkptSuffix) {
254       continue;
255     }
256 
257     ckpt_files_.push_back(file_name);
258   }
259   (void)closedir(dir);
260 
261   if (ckpt_files_.empty()) {
262     MS_LOG(INFO) << "There is no checkpoint file in dir: " << ckpt_save_dir;
263     return;
264   }
265 
266   sort(ckpt_files_.begin(), ckpt_files_.end(), [](const std::string &a, const std::string &b) {
267     auto ckpt_epoch_step_a = ParseCkptEpochStep(a);
268     auto ckpt_epoch_step_b = ParseCkptEpochStep(b);
269     if (ckpt_epoch_step_a.first < ckpt_epoch_step_b.first) {
270       return true;
271     } else if (ckpt_epoch_step_a.first == ckpt_epoch_step_b.first) {
272       return ckpt_epoch_step_a.second < ckpt_epoch_step_b.second;
273     } else {
274       return false;
275     }
276   });
277 
278   const auto &latest_ckpt_name = ckpt_files_.back();
279   latest_ckpt_file_ = ckpt_save_dir + "/" + latest_ckpt_name;
280 
281   auto ckpt_epoch_step = ParseCkptEpochStep(latest_ckpt_name);
282   latest_ckpt_epoch_ = ckpt_epoch_step.first;
283   latest_ckpt_step_ = ckpt_epoch_step.second;
284 }
285 
ParseLatestCkptInfo(const std::vector<int> & recv_buffer)286 void RecoveryContext::ParseLatestCkptInfo(const std::vector<int> &recv_buffer) {
287   std::vector<std::pair<int, int>> ckpts_epoch_step;
288   for (std::size_t i = 0; i + 1 < recv_buffer.size(); i += kSendBufferLen) {
289     (void)ckpts_epoch_step.emplace_back(recv_buffer[i], recv_buffer[i + 1]);
290   }
291   if (ckpts_epoch_step.empty()) {
292     MS_LOG(EXCEPTION) << "Ckpts received is empty.";
293   }
294   sort(ckpts_epoch_step.begin(), ckpts_epoch_step.end(),
295        [](const std::pair<int, int> &a, const std::pair<int, int> &b) {
296          if (a.first < b.first) {
297            return true;
298          } else if (a.first == b.first) {
299            return a.second < b.second;
300          } else {
301            return false;
302          }
303        });
304 
305   const std::pair<int, int> &latest_epoch_step = ckpts_epoch_step.front();
306   latest_ckpt_epoch_ = latest_epoch_step.first;
307   latest_ckpt_step_ = latest_epoch_step.second;
308 
309   const std::string latest_epoch_step_suffix =
310     std::to_string(latest_epoch_step.first) + "_" + std::to_string(latest_epoch_step.second) + kCkptSuffix;
311   auto iter = std::find_if(ckpt_files_.rbegin(), ckpt_files_.rend(), [&](const std::string &file_name) {
312     if (file_name.size() <= latest_epoch_step_suffix.size()) {
313       return false;
314     }
315     return file_name.rfind(latest_epoch_step_suffix) == (file_name.size() - latest_epoch_step_suffix.size());
316   });
317   if (iter == ckpt_files_.rend()) {
318     RemoveAllCkptFiles(GetCkptPath(), ckpt_files_);
319     MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over.";
320   }
321 
322   latest_ckpt_file_ = GetCkptPath() + "/" + *iter;
323 }
324 
CreateConfigFile(const std::string & config_file_path)325 void RecoveryContext::CreateConfigFile(const std::string &config_file_path) {
326   if (storage::FileIOUtils::IsFileOrDirExist(config_file_path)) {
327     MS_LOG(WARNING) << "The config file exists, file path: " << config_file_path;
328     return;
329   }
330 
331   int fd = open(config_file_path.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR);
332   if (fd == -1) {
333     if (errno != EEXIST) {
334       MS_LOG(EXCEPTION) << "Create config file: [" << config_file_path << "] failed, errno: " << errno << ", "
335                         << strerror(errno);
336     }
337     MS_LOG(INFO) << "The config file is already created, file path: " << config_file_path;
338   } else {
339     // Create config file.
340     nlohmann::json config_js;
341     config_js[std::string(ps::kStoreType)] = 1;
342     config_js[std::string(ps::kStoreFilePath)] = recovery_path_ + "/" + ps::kStoreFilePath + kJsonSuffix;
343     config_js[std::string(ps::kSchedulerStoreFilePath)] =
344       recovery_path_ + "/" + ps::kSchedulerStoreFilePath + kJsonSuffix;
345 
346     nlohmann::json recovery_js;
347     recovery_js[std::string(ps::kKeyRecovery)] = config_js;
348 
349     std::string config_content = recovery_js.dump();
350     auto ret_size = write(fd, config_content.c_str(), config_content.size());
351     if (ret_size != SizeToLong(config_content.size())) {
352       (void)close(fd);
353       errno_t err = (ret_size == 0) ? EOF : errno;
354       MS_LOG(EXCEPTION) << "Write config file: [" << config_file_path << "] failed, errno: " << err << ", "
355                         << strerror(err);
356     }
357     (void)close(fd);
358   }
359 }
360 
CreatePersistentFile()361 void RecoveryContext::CreatePersistentFile() {
362   std::unique_lock<std::mutex> lock(create_persist_json_mtx_);
363   if (node_role_ == distributed::kEnvRoleOfScheduler) {
364     return;
365   }
366 
367   if (persistent_json_ != nullptr) {
368     return;
369   }
370 
371   // Need to get real path of recovry path for worker or server.
372   auto ret = FileUtils::GetRealPath(recovery_path_.c_str());
373   if (!ret.has_value()) {
374     MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_;
375   }
376   recovery_path_ = ret.value();
377 
378   // The directory used to save ckpt is persisted to json file.
379   std::string persistent_file_path =
380     recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json";
381   persistent_json_ = std::make_shared<storage::JsonUtils>(persistent_file_path);
382   if (!persistent_json_->Initialize()) {
383     MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path;
384   }
385 }
386 
SetCkptPath(const std::string & path)387 void RecoveryContext::SetCkptPath(const std::string &path) {
388   if (node_role_ == distributed::kEnvRoleOfScheduler) {
389     return;
390   }
391 
392   if (!storage::FileIOUtils::IsFileOrDirExist(path)) {
393     storage::FileIOUtils::CreateDirRecursive(path);
394   }
395 
396   auto ret = FileUtils::GetRealPath(path.c_str());
397   if (!ret.has_value()) {
398     MS_LOG(EXCEPTION) << "Cannot get real path for save checkpoint, path: " << path;
399   }
400 
401   if (persistent_json_ == nullptr) {
402     CreatePersistentFile();
403   }
404 
405   MS_EXCEPTION_IF_NULL(persistent_json_);
406   persistent_json_->Insert(kCkptPath, ret.value());
407 }
408 
GetCkptPath()409 std::string RecoveryContext::GetCkptPath() {
410   if (node_role_ == distributed::kEnvRoleOfScheduler) {
411     return std::string();
412   }
413 
414   if (persistent_json_ == nullptr) {
415     CreatePersistentFile();
416   }
417 
418   MS_EXCEPTION_IF_NULL(persistent_json_);
419   if (!persistent_json_->Exists(kCkptPath)) {
420     return std::string();
421   }
422 
423   return persistent_json_->Get<std::string>(kCkptPath);
424 }
425 
persistent_json()426 const std::shared_ptr<storage::JsonUtils> &RecoveryContext::persistent_json() {
427   if (persistent_json_ == nullptr) {
428     CreatePersistentFile();
429   }
430 
431   MS_EXCEPTION_IF_NULL(persistent_json_);
432   return persistent_json_;
433 }
434 
latest_ckpt_file()435 std::string RecoveryContext::latest_ckpt_file() {
436   // For standalone training.
437   if (enable_recovery_ && global_rank_size_ == 0 && latest_ckpt_file_.empty()) {
438     ObtainLocalLatestCkptInfo();
439   }
440 
441   return latest_ckpt_file_;
442 }
443 }  // namespace recovery
444 }  // namespace distributed
445 }  // namespace mindspore
446