/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "fl/server/common.h"
#include "fl/server/round.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/iteration_metrics.h"

namespace mindspore {
namespace fl {
namespace server {
enum class IterationState {
  // This iteration is still in progress.
  kRunning,
  // This iteration is completed and the next iteration has not started yet.
  kCompleted
};

// The time duration between retries when sending the prepare-for-next-iteration request fails.
constexpr uint32_t kRetryDurationForPrepareForNextIter = 500;

class IterationMetrics;
// In the server's logic, Iteration is the minimum execution unit. Each execution consists of multiple kinds of
// Rounds, and the iteration is considered complete only after all the rounds have finished.
class Iteration {
 public:
  static Iteration &GetInstance() {
    static Iteration instance;
    return instance;
  }

  // Register callbacks for other servers to synchronize iteration information from the leader server.
  void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);

  // Register event callbacks for iteration state synchronization.
  void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);

  // Add a round for the iteration. This method is called multiple times, once for each round.
  void AddRound(const std::shared_ptr<Round> &round);

  // Initialize all the rounds in the iteration.
  void InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
                  const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);

  // Release all the round objects in the Iteration instance. Used for reinitializing rounds and round kernels.
  void ClearRounds();

  // Notify move_to_next_thread_ to move to the next iteration.
  void NotifyNext(bool is_last_iter_valid, const std::string &reason);

  // This method controls how servers proceed to the next iteration.
  // There is communication between the leader and follower servers in this method.
  // The server moves to the next iteration only after the last round finishes or the time expires.
  void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);

  // Set the current iteration state to running and trigger the kIterationRunning event.
  void SetIterationRunning();

  // Set the current iteration state to completed and trigger the kIterationCompleted event.
  void SetIterationCompleted();

  // The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration
  // is completed.
  void ScalingBarrier();

  // Reinitialize rounds after scaling operations are done.
  // The server number after scaling is required in some rounds.
  bool ReInitForScaling(uint32_t server_num, uint32_t server_rank);

  // After hyper-parameters are updated, some rounds and kernels should be reinitialized.
  bool ReInitForUpdatingHyperParams(const std::vector<RoundConfig> &updated_rounds_config);

  const std::vector<std::shared_ptr<Round>> &rounds() const;

  bool is_last_iteration_valid() const;

  // Set the instance metrics, which are summarized at the end of each iteration.
  void set_metrics(const std::shared_ptr<IterationMetrics> &metrics);
  void set_loss(float loss);
  void set_accuracy(float accuracy);

  // Return the state of the current training job instance.
  InstanceState instance_state() const;

  // Return whether the current instance is being updated.
  bool IsInstanceBeingUpdated() const;

  // Enable or disable (enableFLS/disableFLS) the current training instance.
  bool EnableServerInstance(std::string *result);
  bool DisableServerInstance(std::string *result);

  // Finish the current instance and start a new one. The FLPlan could be changed in this method.
  bool NewInstance(const nlohmann::json &new_instance_json, std::string *result);

  // Query information about the current instance.
  bool QueryInstance(std::string *result);

  // Wait for all the rounds to finish before proceeding to the next iteration.
  void WaitAllRoundsFinish() const;

  // The number of round kernels whose Launch method has not returned yet.
  std::atomic_uint32_t running_round_num_;

 private:
  Iteration()
      : running_round_num_(0),
        server_node_(nullptr),
        communicator_(nullptr),
        iteration_state_(IterationState::kCompleted),
        start_timestamp_(0),
        complete_timestamp_(0),
        iteration_loop_count_(0),
        iteration_num_(1),
        is_last_iteration_valid_(true),
        move_to_next_reason_(""),
        move_to_next_thread_running_(true),
        pinned_iter_num_(0),
        metrics_(nullptr),
        instance_state_(InstanceState::kRunning),
        is_instance_being_updated_(false),
        loss_(0.0),
        accuracy_(0.0),
        joined_client_num_(0),
        rejected_client_num_(0),
        time_cost_(0) {
    LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  }
  ~Iteration();
  Iteration(const Iteration &) = delete;
  Iteration &operator=(const Iteration &) = delete;

  // The server does not need to handle the iteration events for now.
  void HandleIterationRunningEvent() {}
  void HandleIterationCompletedEvent() {}

  // Synchronize the iteration from the leader server (rank 0).
  bool SyncIteration(uint32_t rank);
  void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

  // The request for moving to the next iteration is not reentrant.
  bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);

  // The methods for moving all the servers to the next iteration.
  // Step 1: follower servers notify the leader server that they need to move to the next iteration.
  bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
  void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

  // Step 2: the leader server broadcasts to all follower servers to prepare for the next iteration and switch to
  // safemode.
  bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
  void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  // The server prepares for the next iteration. This method switches the server to safemode.
  void PrepareForNextIter();

  // Step 3: the leader server broadcasts to all follower servers to move to the next iteration.
  bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
  void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  // Move to the next iteration. Store the last iteration's model and reset all the rounds.
  void Next(bool is_iteration_valid, const std::string &reason);

  // Step 4: the leader server broadcasts to all follower servers to end the last iteration and cancel safemode.
  bool BroadcastEndLastIterRequest(uint64_t iteration_num);
  void HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
  // The server ends the last iteration. This method increases the iteration number and cancels safemode.
  void EndLastIter();

  // Drop the current iteration and move to the next one immediately.
  bool ForciblyMoveToNextIteration();

  // Summarize metrics for the completed iteration, including iteration time cost, accuracy, loss, etc.
  bool SummarizeIteration();

  // Update the server's hyper-parameters according to the given serialized json (hyper_params_data).
  bool UpdateHyperParams(const nlohmann::json &json);

  // Reinitialize rounds and round kernels.
  bool ReInitRounds();

  std::shared_ptr<ps::core::ServerNode> server_node_;
  std::shared_ptr<ps::core::TcpCommunicator> communicator_;

  // All the rounds in the server.
  std::vector<std::shared_ptr<Round>> rounds_;

  // The iteration is either running or completed at any time.
  std::mutex iteration_state_mtx_;
  std::condition_variable iteration_state_cv_;
  std::atomic<IterationState> iteration_state_;
  uint64_t start_timestamp_;
  uint64_t complete_timestamp_;

  // The count of iteration loops which have been completed.
  size_t iteration_loop_count_;

  // The server's current iteration number.
  size_t iteration_num_;

  // Whether the last iteration finished successfully, and the reason if it did not.
  bool is_last_iteration_valid_;
  std::string move_to_next_reason_;

  // This thread is notified by the rounds when the instance should move to the next iteration.
  std::thread move_to_next_thread_;
  std::atomic_bool move_to_next_thread_running_;
  std::mutex next_iteration_mutex_;
  std::condition_variable next_iteration_cv_;

  // To avoid the Next method being called multiple times in one iteration, the iteration number is pinned here.
  uint64_t pinned_iter_num_;
  std::mutex pinned_mtx_;

  std::shared_ptr<IterationMetrics> metrics_;

  // The state of the current instance.
  std::atomic<InstanceState> instance_state_;

  // Instances are not reentrant.
  // This flag represents whether the instance is being updated.
  std::mutex instance_mtx_;
  bool is_instance_being_updated_;

  // The training loss after this federated learning iteration, passed by the worker.
  float loss_;

  // The evaluation result after this federated learning iteration, passed by the worker.
  float accuracy_;

  // The number of clients which joined the federated aggregation.
  size_t joined_client_num_;

  // The number of clients which were not involved in the federated aggregation.
  size_t rejected_client_num_;

  // The time cost in milliseconds for this completed iteration.
  uint64_t time_cost_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
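
// A minimal usage sketch (illustrative only, not part of this header): the round object, communicator list,
// and callbacks shown below are assumed to be constructed by the server startup code, which is where the real
// call sites live.
//
//   auto &iteration = mindspore::fl::server::Iteration::GetInstance();
//   iteration.AddRound(round);                                   // register each configured round
//   iteration.InitRounds(communicators, timeout_cb, finish_cb);  // initialize all registered rounds
//   iteration.SetIterationRunning();                             // mark the iteration as running
//   iteration.WaitAllRoundsFinish();                             // block until every round has finished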