1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <mutex> 24 #include "utils/ms_utils.h" 25 #include "runtime/collective/collective_communication_lib.h" 26 #include "include/backend/visible.h" 27 28 namespace mindspore { 29 namespace distributed { 30 namespace storage { 31 class FileIOUtils; 32 class JsonUtils; 33 } // namespace storage 34 namespace recovery { 35 constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY"; 36 constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH"; 37 constexpr char kEnvRecoveryInterval[] = "MS_RECOVERY_INTERVAL"; 38 39 bool IsEnableRecovery(); 40 std::string RecoveryPath(); 41 42 // Used to save disaster recovery-related state quantities and provide disaster recovery-related 43 // functions, such as reinitializing collective communication, etc. 44 class BACKEND_EXPORT RecoveryContext { 45 public: GetInstance()46 static std::shared_ptr<RecoveryContext> &GetInstance() { 47 if (instance_ == nullptr) { 48 instance_.reset(new (std::nothrow) RecoveryContext()); 49 MS_EXCEPTION_IF_NULL(instance_); 50 instance_->Initialize(); 51 } 52 return instance_; 53 } 54 ~RecoveryContext() = default; 55 56 // Get whether enable recovery or not. enable_recovery()57 bool enable_recovery() const { return enable_recovery_; } 58 59 // Get the persistent directory. recovery_path()60 const std::string &recovery_path() const { return recovery_path_; } 61 62 // Get interval to persist model. recovery_interval()63 int recovery_interval() const { return recovery_interval_; } 64 65 // Set the path used to save checkpoint. 66 void SetCkptPath(const std::string &path); 67 // Get the path used to save checkpoint. 68 std::string GetCkptPath(); 69 70 // Get the latest checkpoint in this node. 71 std::string latest_ckpt_file(); 72 73 // Get the epoch of latest checkpoint in this node. latest_ckpt_epoch()74 int latest_ckpt_epoch() const { return latest_ckpt_epoch_; } 75 // Get the step of latest checkpoint in this node. latest_ckpt_step()76 int latest_ckpt_step() const { return latest_ckpt_step_; } 77 78 // Set whether need to reset training process or not, if true, all training process need to rollback the same step of 79 // latest checkpoint, including loading checkpoint and reset the minddata. set_need_reset(bool need_reset)80 void set_need_reset(bool need_reset) { need_reset_ = need_reset; } 81 // Get whether need to reset training process or not. need_reset()82 bool need_reset() const { return need_reset_; } 83 84 // Set whether need to sync the weight of model to device. set_need_sync_weight_to_device(bool need_sync_weight_to_device)85 void set_need_sync_weight_to_device(bool need_sync_weight_to_device) { 86 need_sync_weight_to_device_ = need_sync_weight_to_device; 87 } 88 // Get whether need to sync the weight of model to device or not. need_sync_weight_to_device()89 bool need_sync_weight_to_device() const { return need_sync_weight_to_device_; } 90 91 // Set global rank id. set_global_rank_id(uint32_t global_rank_id)92 void set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; } 93 // Set global rank size. set_global_rank_size(uint32_t global_rank_size)94 void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; } 95 96 // Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be 97 // some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the 98 // successful checkpoint in each training process, and then take the minimum value as the final reset position. 99 void ObtainGlobalLatestCkptInfo(); 100 101 // Get the persistent json file pointer. 102 const std::shared_ptr<storage::JsonUtils> &persistent_json(); 103 104 private: 105 inline static std::shared_ptr<RecoveryContext> instance_{}; 106 107 RecoveryContext() = default; 108 DISABLE_COPY_AND_ASSIGN(RecoveryContext); 109 110 // Initialize recovery context. 111 void Initialize(); 112 113 // Create config json file, used to persist node info of cluster. 114 void CreateConfigFile(const std::string &config_file_path); 115 116 // Create persitent json file, used to persist recovery config of Worker, such as ckpt path. 117 void CreatePersistentFile(); 118 119 // Obtain the step corresponding to the local latest checkpoint in each training process. 120 void ObtainLocalLatestCkptInfo(); 121 122 // Parse latest epoch and step info from all latest checkpoints info allgather from other workers. 123 void ParseLatestCkptInfo(const std::vector<int> &recv_buffer); 124 125 // Whether enable recovery or not, set by environment variable 'MS_ENABLE_RECOVERY'. 126 bool enable_recovery_{false}; 127 128 // The persistent directory, set by environment variable 'MS_RECOVERY_PATH'. 129 std::string recovery_path_; 130 131 // The interval to persist model, default value: 30 second. set by environment variable 'MS_RECOVERY_INTERVAL'. 132 int recovery_interval_{30}; 133 134 // Local checkpoint file list. 135 std::vector<std::string> ckpt_files_; 136 // The file name of latest checkpoint. 137 std::string latest_ckpt_file_; 138 // The epoch of latest checkpoint. 139 int latest_ckpt_epoch_{-1}; 140 // The step of latest checkpoint. 141 int latest_ckpt_step_{-1}; 142 143 // Node role in cluster, could be 'MS_WORKER', 'MS_SERVER' or 'MS_SCHED'. 144 std::string node_role_; 145 146 // The global rank id of this process. Normally this range is 0 to `global_rank_size_ - 1`. 147 uint32_t global_rank_id_{0}; 148 // The global rank size. 149 uint32_t global_rank_size_{0}; 150 151 // Whether need to reset training process or not. 152 bool need_reset_{false}; 153 154 // Whether need to sync the weight of model to device, this value needs to be set to true when python layer 155 // performs load checkpoint. 156 bool need_sync_weight_to_device_{false}; 157 158 // Whether the recovery context is already initialized. 159 bool initialized_{false}; 160 161 std::mutex create_persist_json_mtx_; 162 // The persitent json file util, used to persist recovery config. 163 std::shared_ptr<storage::JsonUtils> persistent_json_; 164 }; 165 } // namespace recovery 166 } // namespace distributed 167 } // namespace mindspore 168 #endif // MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_ 169