• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <mutex>
24 #include "utils/ms_utils.h"
25 #include "runtime/collective/collective_communication_lib.h"
26 #include "include/backend/visible.h"
27 
28 namespace mindspore {
29 namespace distributed {
// Forward declarations only: this header refers to these storage utilities through pointers/references
// (see RecoveryContext::persistent_json_), so their full definitions are not needed here.
namespace storage {
class FileIOUtils;
class JsonUtils;
}  // namespace storage
34 namespace recovery {
// Names of the environment variables that configure recovery. RecoveryContext documents which of its fields each
// variable sets (enable flag, persistent directory, persist interval).
constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY";
constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH";
constexpr char kEnvRecoveryInterval[] = "MS_RECOVERY_INTERVAL";

// Free helpers declared here and defined elsewhere.
// NOTE(review): presumably these read kEnvEnableRecovery / kEnvRecoveryPath -- confirm in the implementation file.
bool IsEnableRecovery();
std::string RecoveryPath();
41 
42 // Used to save disaster recovery-related state quantities and provide disaster recovery-related
43 // functions, such as reinitializing collective communication, etc.
44 class BACKEND_EXPORT RecoveryContext {
45  public:
GetInstance()46   static std::shared_ptr<RecoveryContext> &GetInstance() {
47     if (instance_ == nullptr) {
48       instance_.reset(new (std::nothrow) RecoveryContext());
49       MS_EXCEPTION_IF_NULL(instance_);
50       instance_->Initialize();
51     }
52     return instance_;
53   }
54   ~RecoveryContext() = default;
55 
56   // Get whether enable recovery or not.
enable_recovery()57   bool enable_recovery() const { return enable_recovery_; }
58 
59   // Get the persistent directory.
recovery_path()60   const std::string &recovery_path() const { return recovery_path_; }
61 
62   // Get interval to persist model.
recovery_interval()63   int recovery_interval() const { return recovery_interval_; }
64 
65   // Set the path used to save checkpoint.
66   void SetCkptPath(const std::string &path);
67   // Get the path used to save checkpoint.
68   std::string GetCkptPath();
69 
70   // Get the latest checkpoint in this node.
71   std::string latest_ckpt_file();
72 
73   // Get the epoch of latest checkpoint in this node.
latest_ckpt_epoch()74   int latest_ckpt_epoch() const { return latest_ckpt_epoch_; }
75   // Get the step of latest checkpoint in this node.
latest_ckpt_step()76   int latest_ckpt_step() const { return latest_ckpt_step_; }
77 
78   // Set whether need to reset training process or not, if true, all training process need to rollback the same step of
79   // latest checkpoint, including loading checkpoint and reset the minddata.
set_need_reset(bool need_reset)80   void set_need_reset(bool need_reset) { need_reset_ = need_reset; }
81   // Get whether need to reset training process or not.
need_reset()82   bool need_reset() const { return need_reset_; }
83 
84   // Set whether need to sync the weight of model to device.
set_need_sync_weight_to_device(bool need_sync_weight_to_device)85   void set_need_sync_weight_to_device(bool need_sync_weight_to_device) {
86     need_sync_weight_to_device_ = need_sync_weight_to_device;
87   }
88   // Get whether need to sync the weight of model to device or not.
need_sync_weight_to_device()89   bool need_sync_weight_to_device() const { return need_sync_weight_to_device_; }
90 
91   // Set global rank id.
set_global_rank_id(uint32_t global_rank_id)92   void set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
93   // Set global rank size.
set_global_rank_size(uint32_t global_rank_size)94   void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
95 
96   // Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
97   // some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
98   // successful checkpoint in each training process, and then take the minimum value as the final reset position.
99   void ObtainGlobalLatestCkptInfo();
100 
101   // Get the persistent json file pointer.
102   const std::shared_ptr<storage::JsonUtils> &persistent_json();
103 
104  private:
105   inline static std::shared_ptr<RecoveryContext> instance_{};
106 
107   RecoveryContext() = default;
108   DISABLE_COPY_AND_ASSIGN(RecoveryContext);
109 
110   // Initialize recovery context.
111   void Initialize();
112 
113   // Create config json file, used to persist node info of cluster.
114   void CreateConfigFile(const std::string &config_file_path);
115 
116   // Create persitent json file, used to persist recovery config of Worker, such as ckpt path.
117   void CreatePersistentFile();
118 
119   // Obtain the step corresponding to the local latest checkpoint in each training process.
120   void ObtainLocalLatestCkptInfo();
121 
122   // Parse latest epoch and step info from all latest checkpoints info allgather from other workers.
123   void ParseLatestCkptInfo(const std::vector<int> &recv_buffer);
124 
125   // Whether enable recovery or not, set by environment variable 'MS_ENABLE_RECOVERY'.
126   bool enable_recovery_{false};
127 
128   // The persistent directory, set by environment variable 'MS_RECOVERY_PATH'.
129   std::string recovery_path_;
130 
131   // The interval to persist model, default value: 30 second. set by environment variable 'MS_RECOVERY_INTERVAL'.
132   int recovery_interval_{30};
133 
134   // Local checkpoint file list.
135   std::vector<std::string> ckpt_files_;
136   // The file name of latest checkpoint.
137   std::string latest_ckpt_file_;
138   // The epoch of latest checkpoint.
139   int latest_ckpt_epoch_{-1};
140   // The step of latest checkpoint.
141   int latest_ckpt_step_{-1};
142 
143   // Node role in cluster, could be 'MS_WORKER', 'MS_SERVER' or 'MS_SCHED'.
144   std::string node_role_;
145 
146   // The global rank id of this process. Normally this range is 0 to `global_rank_size_ - 1`.
147   uint32_t global_rank_id_{0};
148   // The global rank size.
149   uint32_t global_rank_size_{0};
150 
151   // Whether need to reset training process or not.
152   bool need_reset_{false};
153 
154   // Whether need to sync the weight of model to device, this value needs to be set to true when python layer
155   // performs load checkpoint.
156   bool need_sync_weight_to_device_{false};
157 
158   // Whether the recovery context is already initialized.
159   bool initialized_{false};
160 
161   std::mutex create_persist_json_mtx_;
162   // The persitent json file util, used to persist recovery config.
163   std::shared_ptr<storage::JsonUtils> persistent_json_;
164 };
165 }  // namespace recovery
166 }  // namespace distributed
167 }  // namespace mindspore
168 #endif  // MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
169