1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/backend/distributed/recovery/recovery_context.h"

#include <dirent.h>
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <map>
#include <memory>
#include <stdexcept>
#include <utility>

#include "nlohmann/json.hpp"
#include "include/backend/distributed/ps/ps_context.h"
#include "include/backend/distributed/ps/constants.h"
#include "utils/file_utils.h"
#include "include/backend/distributed/constants.h"
#include "distributed/persistent/storage/file_io_utils.h"
#include "distributed/persistent/storage/json_utils.h"
#include "include/backend/distributed/cluster/topology/common.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
#include "include/backend/distributed/cluster/cluster_context.h"
#include "include/backend/distributed/cluster/topology/compute_graph_node.h"
#endif
#include "runtime/hardware/device_context_manager.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_context.h"
39
40 namespace mindspore {
41 namespace distributed {
42 namespace recovery {
// File name suffix that identifies a checkpoint file.
constexpr char kCkptSuffix[] = ".ckpt";
// Key under which the checkpoint directory is stored in the persistent json file.
constexpr char kCkptPath[] = "ckpt_path";
// Suffix appended to the store file names written into the recovery config.
constexpr char kJsonSuffix[] = ".json";
// Recovery config file name, appended to the recovery path.
constexpr char kConfigJson[] = "/config.json";

// Number of ints each rank contributes when exchanging checkpoint info: (epoch, step).
constexpr uint32_t kSendBufferLen = 2;

// Metadata key prefixes for the exchanged checkpoint epoch/step; the rank id is appended to each.
constexpr char kCkptEpochInfoPrefix[] = "ckpt_epoch_rank_";
constexpr char kCkptStepInfoPrefix[] = "ckpt_step_rank_";
52
53 namespace {
ParseCkptEpochStep(const std::string & checkpoint)54 std::pair<int, int> ParseCkptEpochStep(const std::string &checkpoint) {
55 size_t suffix_pos = checkpoint.rfind('.');
56 if (suffix_pos == std::string::npos || checkpoint.substr(suffix_pos) != kCkptSuffix) {
57 MS_LOG(WARNING) << "The file : " << checkpoint << "is not a checkpoint";
58 return {};
59 }
60
61 size_t epoch_begin_pos = checkpoint.rfind('-');
62 size_t step_begin_pos = checkpoint.rfind('_');
63 if (epoch_begin_pos == std::string::npos || step_begin_pos == std::string::npos) {
64 MS_LOG(EXCEPTION) << "The checkpoint file name is not valid: " << checkpoint;
65 }
66
67 return std::make_pair(std::stoi(checkpoint.substr(epoch_begin_pos + 1, (step_begin_pos - epoch_begin_pos) - 1)),
68 std::stoi(checkpoint.substr(step_begin_pos + 1, (suffix_pos - step_begin_pos) - 1)));
69 }
70
RemoveAllCkptFiles(const std::string & directory,const std::vector<std::string> & files_list)71 void RemoveAllCkptFiles(const std::string &directory, const std::vector<std::string> &files_list) {
72 for (size_t i = 0; i < files_list.size(); i++) {
73 const auto &ckpt_name = files_list[i];
74 const auto &ckpt_file = directory + "/" + ckpt_name;
75 (void)remove(ckpt_file.c_str());
76 }
77 }
78 } // namespace
79
IsEnableRecovery()80 bool IsEnableRecovery() { return common::GetEnv(kEnvEnableRecovery) == std::string("1"); }
81
RecoveryPath()82 std::string RecoveryPath() { return common::GetEnv(kEnvRecoveryPath); }
83
Initialize()84 void RecoveryContext::Initialize() {
85 if (initialized_) {
86 return;
87 }
88
89 // 1. Read environment variable.
90 enable_recovery_ = IsEnableRecovery();
91 if (!enable_recovery_) {
92 return;
93 }
94
95 auto context_ptr = MsContext::GetInstance();
96 MS_EXCEPTION_IF_NULL(context_ptr);
97 context_ptr->set_param<bool>(MS_CTX_ENABLE_RECOVERY, true);
98
99 recovery_path_ = RecoveryPath();
100 if (recovery_path_.empty()) {
101 MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
102 }
103
104 auto env_recovery_interval = common::GetEnv(kEnvRecoveryInterval);
105 if (!env_recovery_interval.empty()) {
106 recovery_interval_ = std::stoi(env_recovery_interval);
107 }
108
109 node_role_ = common::GetEnv(distributed::kEnvRole);
110 if (distributed::kValidRoleName.count(node_role_) == 0) {
111 MS_LOG(EXCEPTION) << "Role name '" << node_role_ << "' is invalid. ";
112 }
113
114 // 2. Get real recovery path and create config file.
115 if (!storage::FileIOUtils::IsFileOrDirExist(recovery_path_)) {
116 storage::FileIOUtils::CreateDirRecursive(recovery_path_);
117 }
118
119 auto ret = FileUtils::GetRealPath(recovery_path_.c_str());
120 if (!ret.has_value()) {
121 MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_;
122 }
123 recovery_path_ = ret.value();
124
125 std::string config_file_path = recovery_path_ + kConfigJson;
126 if (!storage::FileIOUtils::IsFileOrDirExist(config_file_path)) {
127 CreateConfigFile(config_file_path);
128 }
129
130 // 3. Set config content to PSContext.
131 ps::PSContext::instance()->set_config_file_path(config_file_path);
132 ps::PSContext::instance()->set_node_id(common::GetEnv(distributed::cluster::topology::kEnvNodeId));
133
134 initialized_ = true;
135 }
136
ObtainGlobalLatestCkptInfo()137 void RecoveryContext::ObtainGlobalLatestCkptInfo() {
138 // 1. Obtain the step corresponding to the local latest checkpoint.
139 ObtainLocalLatestCkptInfo();
140
141 // For standalone training.
142 if (global_rank_size_ == 0) {
143 return;
144 }
145
146 // 2. AllGather the latest checkpoint info of all nodes.
147 device::DeviceContextKey host_key = {"CPU", 0};
148 device::DeviceContext *host_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
149 MS_EXCEPTION_IF_NULL(host_context);
150 MS_EXCEPTION_IF_NULL(host_context->device_res_manager_);
151 device::CollectiveCommunicationLib *host_comm_lib_instance = host_context->device_res_manager_->collective_comm_lib();
152 MS_EXCEPTION_IF_NULL(host_comm_lib_instance);
153
154 if (global_rank_id_ >= global_rank_size_) {
155 MS_LOG(EXCEPTION) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
156 << global_rank_size_;
157 }
158
159 const std::size_t kRecvBufferLen = kSendBufferLen * global_rank_size_;
160
161 std::vector<int> recv_buffer(kRecvBufferLen, 0);
162
163 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
164 // Synchronize the checkpoint information between all the other nodes to ensure the accuracy of training.
165 auto node = cluster::ClusterContext::instance()->node();
166 MS_EXCEPTION_IF_NULL(node);
167 auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(node);
168 MS_EXCEPTION_IF_NULL(cgn);
169
170 // Start the ckpt file info exchange process.
171 std::map<std::string, std::string> results;
172 const std::string biz = "sync_ckpt";
173
174 std::vector<std::string> names_prefix;
175 (void)names_prefix.emplace_back(kCkptEpochInfoPrefix);
176 (void)names_prefix.emplace_back(kCkptStepInfoPrefix);
177
178 std::vector<std::string> values;
179 values.push_back(std::to_string(latest_ckpt_epoch_));
180 values.push_back(std::to_string(latest_ckpt_step_));
181
182 if (cgn->ExchangeMetadata(biz, global_rank_size_, names_prefix, values, &results, INT_MAX)) {
183 for (uint32_t i = 0; i < global_rank_size_; ++i) {
184 auto epoch_key = kCkptEpochInfoPrefix + std::to_string(i);
185 auto step_key = kCkptStepInfoPrefix + std::to_string(i);
186 auto ckpt_epoch = results[epoch_key];
187 auto ckpt_step = results[step_key];
188 if (ckpt_epoch.length() > 0 && ckpt_step.length() > 0) {
189 recv_buffer[kSendBufferLen * i] = std::stoi(ckpt_epoch);
190 recv_buffer[kSendBufferLen * i + 1] = std::stoi(ckpt_step);
191 MS_LOG(INFO) << "The latest checkpoint for rank " << i << "is that epoch: " << ckpt_epoch
192 << ", step: " << ckpt_step;
193 }
194 }
195 MS_LOG(INFO) << "The checkpoint information of all the ranks have been synchronized.";
196 }
197 #endif
198
199 // 3. Check whether save checkpoint successfully on every workers.
200 uint32_t save_ckpt_success_num = 0;
201 uint32_t save_ckpt_failed_num = 0;
202 for (uint32_t i = 0; i < kRecvBufferLen; i += kSendBufferLen) {
203 if (recv_buffer[i] < 0) {
204 save_ckpt_failed_num++;
205 } else {
206 save_ckpt_success_num++;
207 }
208 }
209
210 if (save_ckpt_success_num > 0 && save_ckpt_failed_num > 0) {
211 RemoveAllCkptFiles(GetCkptPath(), ckpt_files_);
212 MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over.";
213 }
214 if (save_ckpt_success_num == 0 && save_ckpt_failed_num == global_rank_size_) {
215 return;
216 }
217
218 // 4. Parse latest epoch and step info.
219 ParseLatestCkptInfo(recv_buffer);
220
221 // 5. Remove useless ckpt
222 for (int i = SizeToInt(ckpt_files_.size()) - 1; i >= 0; i--) {
223 const auto &last_ckpt_name = ckpt_files_[IntToSize(i)];
224 const auto &last_ckpt_file = GetCkptPath() + "/" + last_ckpt_name;
225 if (last_ckpt_file != latest_ckpt_file_) {
226 (void)remove(last_ckpt_file.c_str());
227 } else {
228 break;
229 }
230 }
231 }
232
ObtainLocalLatestCkptInfo()233 void RecoveryContext::ObtainLocalLatestCkptInfo() {
234 std::string ckpt_save_dir = GetCkptPath();
235 if (ckpt_save_dir.empty()) {
236 MS_LOG(INFO) << "The ckpt file path is empty";
237 return;
238 }
239
240 DIR *dir = opendir(ckpt_save_dir.c_str());
241 if (dir == nullptr) {
242 MS_LOG(EXCEPTION) << "The file path [" << ckpt_save_dir << "] is not exist";
243 }
244
245 if (!ckpt_files_.empty()) {
246 ckpt_files_.clear();
247 }
248
249 struct dirent *entry;
250 while ((entry = readdir(dir)) != nullptr) {
251 std::string file_name = entry->d_name;
252 size_t suffix_pos = file_name.rfind('.');
253 if (suffix_pos == std::string::npos || file_name.substr(suffix_pos) != kCkptSuffix) {
254 continue;
255 }
256
257 ckpt_files_.push_back(file_name);
258 }
259 (void)closedir(dir);
260
261 if (ckpt_files_.empty()) {
262 MS_LOG(INFO) << "There is no checkpoint file in dir: " << ckpt_save_dir;
263 return;
264 }
265
266 sort(ckpt_files_.begin(), ckpt_files_.end(), [](const std::string &a, const std::string &b) {
267 auto ckpt_epoch_step_a = ParseCkptEpochStep(a);
268 auto ckpt_epoch_step_b = ParseCkptEpochStep(b);
269 if (ckpt_epoch_step_a.first < ckpt_epoch_step_b.first) {
270 return true;
271 } else if (ckpt_epoch_step_a.first == ckpt_epoch_step_b.first) {
272 return ckpt_epoch_step_a.second < ckpt_epoch_step_b.second;
273 } else {
274 return false;
275 }
276 });
277
278 const auto &latest_ckpt_name = ckpt_files_.back();
279 latest_ckpt_file_ = ckpt_save_dir + "/" + latest_ckpt_name;
280
281 auto ckpt_epoch_step = ParseCkptEpochStep(latest_ckpt_name);
282 latest_ckpt_epoch_ = ckpt_epoch_step.first;
283 latest_ckpt_step_ = ckpt_epoch_step.second;
284 }
285
ParseLatestCkptInfo(const std::vector<int> & recv_buffer)286 void RecoveryContext::ParseLatestCkptInfo(const std::vector<int> &recv_buffer) {
287 std::vector<std::pair<int, int>> ckpts_epoch_step;
288 for (std::size_t i = 0; i + 1 < recv_buffer.size(); i += kSendBufferLen) {
289 (void)ckpts_epoch_step.emplace_back(recv_buffer[i], recv_buffer[i + 1]);
290 }
291 if (ckpts_epoch_step.empty()) {
292 MS_LOG(EXCEPTION) << "Ckpts received is empty.";
293 }
294 sort(ckpts_epoch_step.begin(), ckpts_epoch_step.end(),
295 [](const std::pair<int, int> &a, const std::pair<int, int> &b) {
296 if (a.first < b.first) {
297 return true;
298 } else if (a.first == b.first) {
299 return a.second < b.second;
300 } else {
301 return false;
302 }
303 });
304
305 const std::pair<int, int> &latest_epoch_step = ckpts_epoch_step.front();
306 latest_ckpt_epoch_ = latest_epoch_step.first;
307 latest_ckpt_step_ = latest_epoch_step.second;
308
309 const std::string latest_epoch_step_suffix =
310 std::to_string(latest_epoch_step.first) + "_" + std::to_string(latest_epoch_step.second) + kCkptSuffix;
311 auto iter = std::find_if(ckpt_files_.rbegin(), ckpt_files_.rend(), [&](const std::string &file_name) {
312 if (file_name.size() <= latest_epoch_step_suffix.size()) {
313 return false;
314 }
315 return file_name.rfind(latest_epoch_step_suffix) == (file_name.size() - latest_epoch_step_suffix.size());
316 });
317 if (iter == ckpt_files_.rend()) {
318 RemoveAllCkptFiles(GetCkptPath(), ckpt_files_);
319 MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over.";
320 }
321
322 latest_ckpt_file_ = GetCkptPath() + "/" + *iter;
323 }
324
CreateConfigFile(const std::string & config_file_path)325 void RecoveryContext::CreateConfigFile(const std::string &config_file_path) {
326 if (storage::FileIOUtils::IsFileOrDirExist(config_file_path)) {
327 MS_LOG(WARNING) << "The config file exists, file path: " << config_file_path;
328 return;
329 }
330
331 int fd = open(config_file_path.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR);
332 if (fd == -1) {
333 if (errno != EEXIST) {
334 MS_LOG(EXCEPTION) << "Create config file: [" << config_file_path << "] failed, errno: " << errno << ", "
335 << strerror(errno);
336 }
337 MS_LOG(INFO) << "The config file is already created, file path: " << config_file_path;
338 } else {
339 // Create config file.
340 nlohmann::json config_js;
341 config_js[std::string(ps::kStoreType)] = 1;
342 config_js[std::string(ps::kStoreFilePath)] = recovery_path_ + "/" + ps::kStoreFilePath + kJsonSuffix;
343 config_js[std::string(ps::kSchedulerStoreFilePath)] =
344 recovery_path_ + "/" + ps::kSchedulerStoreFilePath + kJsonSuffix;
345
346 nlohmann::json recovery_js;
347 recovery_js[std::string(ps::kKeyRecovery)] = config_js;
348
349 std::string config_content = recovery_js.dump();
350 auto ret_size = write(fd, config_content.c_str(), config_content.size());
351 if (ret_size != SizeToLong(config_content.size())) {
352 (void)close(fd);
353 errno_t err = (ret_size == 0) ? EOF : errno;
354 MS_LOG(EXCEPTION) << "Write config file: [" << config_file_path << "] failed, errno: " << err << ", "
355 << strerror(err);
356 }
357 (void)close(fd);
358 }
359 }
360
CreatePersistentFile()361 void RecoveryContext::CreatePersistentFile() {
362 std::unique_lock<std::mutex> lock(create_persist_json_mtx_);
363 if (node_role_ == distributed::kEnvRoleOfScheduler) {
364 return;
365 }
366
367 if (persistent_json_ != nullptr) {
368 return;
369 }
370
371 // Need to get real path of recovry path for worker or server.
372 auto ret = FileUtils::GetRealPath(recovery_path_.c_str());
373 if (!ret.has_value()) {
374 MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_;
375 }
376 recovery_path_ = ret.value();
377
378 // The directory used to save ckpt is persisted to json file.
379 std::string persistent_file_path =
380 recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json";
381 persistent_json_ = std::make_shared<storage::JsonUtils>(persistent_file_path);
382 if (!persistent_json_->Initialize()) {
383 MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path;
384 }
385 }
386
SetCkptPath(const std::string & path)387 void RecoveryContext::SetCkptPath(const std::string &path) {
388 if (node_role_ == distributed::kEnvRoleOfScheduler) {
389 return;
390 }
391
392 if (!storage::FileIOUtils::IsFileOrDirExist(path)) {
393 storage::FileIOUtils::CreateDirRecursive(path);
394 }
395
396 auto ret = FileUtils::GetRealPath(path.c_str());
397 if (!ret.has_value()) {
398 MS_LOG(EXCEPTION) << "Cannot get real path for save checkpoint, path: " << path;
399 }
400
401 if (persistent_json_ == nullptr) {
402 CreatePersistentFile();
403 }
404
405 MS_EXCEPTION_IF_NULL(persistent_json_);
406 persistent_json_->Insert(kCkptPath, ret.value());
407 }
408
GetCkptPath()409 std::string RecoveryContext::GetCkptPath() {
410 if (node_role_ == distributed::kEnvRoleOfScheduler) {
411 return std::string();
412 }
413
414 if (persistent_json_ == nullptr) {
415 CreatePersistentFile();
416 }
417
418 MS_EXCEPTION_IF_NULL(persistent_json_);
419 if (!persistent_json_->Exists(kCkptPath)) {
420 return std::string();
421 }
422
423 return persistent_json_->Get<std::string>(kCkptPath);
424 }
425
persistent_json()426 const std::shared_ptr<storage::JsonUtils> &RecoveryContext::persistent_json() {
427 if (persistent_json_ == nullptr) {
428 CreatePersistentFile();
429 }
430
431 MS_EXCEPTION_IF_NULL(persistent_json_);
432 return persistent_json_;
433 }
434
latest_ckpt_file()435 std::string RecoveryContext::latest_ckpt_file() {
436 // For standalone training.
437 if (enable_recovery_ && global_rank_size_ == 0 && latest_ckpt_file_.empty()) {
438 ObtainLocalLatestCkptInfo();
439 }
440
441 return latest_ckpt_file_;
442 }
443 } // namespace recovery
444 } // namespace distributed
445 } // namespace mindspore
446