• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
19 
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "fl/server/common.h"
#include "fl/server/round.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/iteration_metrics.h"
28 
29 namespace mindspore {
30 namespace fl {
31 namespace server {
// Lifecycle state of one federated-learning iteration on this server.
// At any moment the iteration is in exactly one of these two states.
enum class IterationState {
  kRunning = 0,    // The iteration is still in progress.
  kCompleted = 1,  // The iteration has finished and the next one has not started yet.
};
38 
// Back-off interval waited between retries when sending the "prepare for next iteration" request fails.
// NOTE(review): the unit is not visible in this header (presumably milliseconds) — confirm at the call site.
constexpr uint32_t kRetryDurationForPrepareForNextIter{500};

// Forward declaration of IterationMetrics, kept even though iteration_metrics.h is included above.
// NOTE(review): likely required because of an include cycle between iteration.h and iteration_metrics.h —
// confirm before removing.
class IterationMetrics;
43 // In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of
44 // Rounds, only after all the rounds are finished, this iteration is considered as completed.
45 class Iteration {
46  public:
GetInstance()47   static Iteration &GetInstance() {
48     static Iteration instance;
49     return instance;
50   }
51 
52   // Register callbacks for other servers to synchronize iteration information from leader server.
53   void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
54 
55   // Register event callbacks for iteration state synchronization.
56   void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
57 
58   // Add a round for the iteration. This method will be called multiple times for each round.
59   void AddRound(const std::shared_ptr<Round> &round);
60 
61   // Initialize all the rounds in the iteration.
62   void InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
63                   const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
64 
65   // Release all the round objects in Iteration instance. Used for reinitializing round and round kernels.
66   void ClearRounds();
67 
68   // Notify move_to_next_thread_ to move to next iteration.
69   void NotifyNext(bool is_last_iter_valid, const std::string &reason);
70 
71   // This method will control servers to proceed to next iteration.
72   // There's communication between leader and follower servers in this method.
73   // The server moves to next iteration only after the last round finishes or the time expires.
74   void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
75 
76   // Set current iteration state to running and trigger events about kIterationRunning.
77   void SetIterationRunning();
78 
79   // Set current iteration state to completed and trigger the event about kIterationCompleted.
80   void SetIterationCompleted();
81 
82   // The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is
83   // completed.
84   void ScalingBarrier();
85 
86   // Reinitialize rounds after scaling operations are done.
87   // The server number after scaling is required in some rounds.
88   bool ReInitForScaling(uint32_t server_num, uint32_t server_rank);
89 
90   // After hyper-parameters are updated, some rounds and kernels should be reinitialized.
91   bool ReInitForUpdatingHyperParams(const std::vector<RoundConfig> &updated_rounds_config);
92 
93   const std::vector<std::shared_ptr<Round>> &rounds() const;
94 
95   bool is_last_iteration_valid() const;
96 
97   // Set the instance metrics which will be called for each iteration.
98   void set_metrics(const std::shared_ptr<IterationMetrics> &metrics);
99   void set_loss(float loss);
100   void set_accuracy(float accuracy);
101 
102   // Return state of current training job instance.
103   InstanceState instance_state() const;
104 
105   // Return whether current instance is being updated.
106   bool IsInstanceBeingUpdated() const;
107 
108   // EnableFLS/disableFLS the current training instance.
109   bool EnableServerInstance(std::string *result);
110   bool DisableServerInstance(std::string *result);
111 
112   // Finish current instance and start a new one. FLPlan could be changed in this method.
113   bool NewInstance(const nlohmann::json &new_instance_json, std::string *result);
114 
115   // Query information of current instance.
116   bool QueryInstance(std::string *result);
117 
118   // Need to wait all the rounds to finish before proceed to next iteration.
119   void WaitAllRoundsFinish() const;
120 
121   // The round kernels whose Launch method has not returned yet.
122   std::atomic_uint32_t running_round_num_;
123 
124  private:
Iteration()125   Iteration()
126       : running_round_num_(0),
127         server_node_(nullptr),
128         communicator_(nullptr),
129         iteration_state_(IterationState::kCompleted),
130         start_timestamp_(0),
131         complete_timestamp_(0),
132         iteration_loop_count_(0),
133         iteration_num_(1),
134         is_last_iteration_valid_(true),
135         move_to_next_reason_(""),
136         move_to_next_thread_running_(true),
137         pinned_iter_num_(0),
138         metrics_(nullptr),
139         instance_state_(InstanceState::kRunning),
140         is_instance_being_updated_(false),
141         loss_(0.0),
142         accuracy_(0.0),
143         joined_client_num_(0),
144         rejected_client_num_(0),
145         time_cost_(0) {
146     LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
147   }
148   ~Iteration();
149   Iteration(const Iteration &) = delete;
150   Iteration &operator=(const Iteration &) = delete;
151 
152   // The server does not need to handle the iteration events for now.
HandleIterationRunningEvent()153   void HandleIterationRunningEvent() {}
HandleIterationCompletedEvent()154   void HandleIterationCompletedEvent() {}
155 
156   // Synchronize iteration form the leader server(Rank 0).
157   bool SyncIteration(uint32_t rank);
158   void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
159 
160   // The request for moving to next iteration is not reentrant.
161   bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);
162 
163   // The methods for moving to next iteration for all the servers.
164   // Step 1: follower servers notify leader server that they need to move to next iteration.
165   bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
166   void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
167 
168   // Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode.
169   bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
170   void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
171   // The server prepare for the next iteration. This method will switch the server to safemode.
172   void PrepareForNextIter();
173 
174   // Step 3: leader server broadcast to all follower servers to move to next iteration.
175   bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
176   void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
177   // Move to next iteration. Store last iterations model and reset all the rounds.
178   void Next(bool is_iteration_valid, const std::string &reason);
179 
180   // Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
181   bool BroadcastEndLastIterRequest(uint64_t iteration_num);
182   void HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
183   // The server end the last iteration. This method will increase the iteration number and cancel the safemode.
184   void EndLastIter();
185 
186   // Drop current iteration and move to the next immediately.
187   bool ForciblyMoveToNextIteration();
188 
189   // Summarize metrics for the completed iteration, including iteration time cost, accuracy, loss, etc.
190   bool SummarizeIteration();
191 
192   // Update server's hyper-parameters according to the given serialized json(hyper_params_data).
193   bool UpdateHyperParams(const nlohmann::json &json);
194 
195   // Reinitialize rounds and round kernels.
196   bool ReInitRounds();
197 
198   std::shared_ptr<ps::core::ServerNode> server_node_;
199   std::shared_ptr<ps::core::TcpCommunicator> communicator_;
200 
201   // All the rounds in the server.
202   std::vector<std::shared_ptr<Round>> rounds_;
203 
204   // The iteration is either running or completed at any time.
205   std::mutex iteration_state_mtx_;
206   std::condition_variable iteration_state_cv_;
207   std::atomic<IterationState> iteration_state_;
208   uint64_t start_timestamp_;
209   uint64_t complete_timestamp_;
210 
211   // The count of iteration loops which are completed.
212   size_t iteration_loop_count_;
213 
214   // Server's current iteration number.
215   size_t iteration_num_;
216 
217   // Whether last iteration is successfully finished and the reason.
218   bool is_last_iteration_valid_;
219   std::string move_to_next_reason_;
220 
221   // It will be notified by rounds that the instance moves to the next iteration.
222   std::thread move_to_next_thread_;
223   std::atomic_bool move_to_next_thread_running_;
224   std::mutex next_iteration_mutex_;
225   std::condition_variable next_iteration_cv_;
226 
227   // To avoid Next method is called multiple times in one iteration, we should mark the iteration number.
228   uint64_t pinned_iter_num_;
229   std::mutex pinned_mtx_;
230 
231   std::shared_ptr<IterationMetrics> metrics_;
232 
233   // The state for current instance.
234   std::atomic<InstanceState> instance_state_;
235 
236   // Every instance is not reentrant.
237   // This flag represents whether the instance is being updated.
238   std::mutex instance_mtx_;
239   bool is_instance_being_updated_;
240 
241   // The training loss after this federated learning iteration, passed by worker.
242   float loss_;
243 
244   // The evaluation result after this federated learning iteration, passed by worker.
245   float accuracy_;
246 
247   // The number of clients which join the federated aggregation.
248   size_t joined_client_num_;
249 
250   // The number of clients which are not involved in federated aggregation.
251   size_t rejected_client_num_;
252 
253   // The time cost in millisecond for this completed iteration.
254   uint64_t time_cost_;
255 };
256 }  // namespace server
257 }  // namespace fl
258 }  // namespace mindspore
259 #endif  // MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
260