/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/server/iteration.h"
#include <memory>
#include <string>
#include <vector>
#include <numeric>
#include "fl/server/model_store.h"
#include "fl/server/server.h"

namespace mindspore {
namespace fl {
namespace server {
class Server;

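// Stop the background transition thread: clear the running flag, wake the thread through the
// condition variable, and join it before the object is destroyed.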
Iteration::~Iteration() {
  move_to_next_thread_running_ = false;
  next_iteration_cv_.notify_all();
  if (move_to_next_thread_.joinable()) {
    move_to_next_thread_.join();
  }
}

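// Register the TCP message handlers that implement the iteration protocol: iteration-number
// synchronization, the follower-to-leader move-to-next-iteration notification, and the leader's
// prepare/proceed/end broadcasts.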
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
  MS_EXCEPTION_IF_NULL(communicator);
  communicator_ = communicator;
  communicator_->RegisterMsgCallBack("syncIteration",
                                     std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "notifyLeaderToNextIter",
    std::bind(&Iteration::HandleNotifyLeaderMoveToNextIterRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "prepareForNextIter", std::bind(&Iteration::HandlePrepareForNextIterRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack("proceedToNextIter",
                                     std::bind(&Iteration::HandleMoveToNextIterRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack("endLastIter",
                                     std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
}

void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
  MS_EXCEPTION_IF_NULL(server_node);
  server_node_ = server_node;
  server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
                                           std::bind(&Iteration::HandleIterationRunningEvent, this));
  server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
                                           std::bind(&Iteration::HandleIterationCompletedEvent, this));
}

void Iteration::AddRound(const std::shared_ptr<Round> &round) {
  MS_EXCEPTION_IF_NULL(round);
  rounds_.push_back(round);
}

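// Bind every communicator to every registered round, accumulate the per-round time windows into the
// total iteration time window, and launch the background thread that performs the actual move to the
// next iteration whenever NotifyNext signals the condition variable.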
void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
                           const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
  if (communicators.empty()) {
    MS_LOG(EXCEPTION) << "Communicators for rounds are empty.";
    return;
  }

  (void)std::for_each(communicators.begin(), communicators.end(),
                      [&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
                        for (auto &round : rounds_) {
                          MS_EXCEPTION_IF_NULL(round);
                          round->Initialize(communicator, timeout_cb, finish_iteration_cb);
                        }
                      });

  // The time window for one iteration, which will be used in some round kernels.
  size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
                                                 [](size_t total, const std::shared_ptr<Round> &round) {
                                                   MS_EXCEPTION_IF_NULL(round);
                                                   return round->check_timeout() ? total + round->time_window() : total;
                                                 });
  LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
  MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;

  // Initialize the thread which handles the move-to-next-iteration signal from round kernels.
  move_to_next_thread_ = std::thread([this]() {
    while (move_to_next_thread_running_.load()) {
      std::unique_lock<std::mutex> lock(next_iteration_mutex_);
      next_iteration_cv_.wait(lock);
      if (!move_to_next_thread_running_.load()) {
        break;
      }
      MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
    }
  });
}

void Iteration::ClearRounds() { rounds_.clear(); }

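// Signal the background thread to move the cluster to the next iteration. Round kernels call this
// instead of MoveToNextIteration directly, so the transition runs asynchronously and never blocks
// the caller.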
void Iteration::NotifyNext(bool is_last_iter_valid, const std::string &reason) {
  std::unique_lock<std::mutex> lock(next_iteration_mutex_);
  is_last_iteration_valid_ = is_last_iter_valid;
  move_to_next_reason_ = reason;
  next_iteration_cv_.notify_one();
}

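// Drive the cluster to the next iteration. The leader server broadcasts the three-phase protocol
// (prepare -> proceed -> end last iteration) to all followers; a follower instead asks the leader
// to run that protocol on its behalf.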
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
  MS_LOG(INFO) << "Notify the cluster to proceed to the next iteration. Current iteration is " << iteration_num_
               << ", last iteration valid: " << is_last_iter_valid << ". Reason: " << reason;
  if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
    return;
  }

  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  if (server_node_->rank_id() == kLeaderServerRank) {
    if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
      MS_LOG(ERROR) << "Broadcasting the prepare-for-next-iteration request failed.";
      return;
    }
    if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
      MS_LOG(ERROR) << "Broadcasting the proceed-to-next-iteration request failed.";
      return;
    }
    if (!BroadcastEndLastIterRequest(iteration_num_)) {
      MS_LOG(ERROR) << "Broadcasting the end-last-iteration request failed.";
      return;
    }
  } else {
    // If this server is a follower, notify the leader server to move the cluster to the next iteration.
    if (!NotifyLeaderMoveToNextIteration(is_last_iter_valid, reason)) {
      MS_LOG(ERROR) << "Server " << server_node_->rank_id() << " failed to notify the leader server.";
      return;
    }
  }
}

void Iteration::SetIterationRunning() {
  MS_LOG(INFO) << "Iteration " << iteration_num_ << " starts running.";
  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  if (server_node_->rank_id() == kLeaderServerRank) {
    // This event helps workers/servers stay consistent in iteration state.
    server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
  }

  std::unique_lock<std::mutex> lock(iteration_state_mtx_);
  iteration_state_ = IterationState::kRunning;
  start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
}

void Iteration::SetIterationCompleted() {
  MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  if (server_node_->rank_id() == kLeaderServerRank) {
    // This event helps workers/servers stay consistent in iteration state.
    server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
  }

  std::unique_lock<std::mutex> lock(iteration_state_mtx_);
  iteration_state_ = IterationState::kCompleted;
  complete_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
}

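// Block until the current iteration completes. Used as a barrier during elastic scaling so that
// scaling operations never interleave with a running iteration.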
void Iteration::ScalingBarrier() {
  MS_LOG(INFO) << "Starting Iteration scaling barrier.";
  std::unique_lock<std::mutex> lock(iteration_state_mtx_);
  if (iteration_state_.load() != IterationState::kCompleted) {
    iteration_state_cv_.wait(lock);
  }
  MS_LOG(INFO) << "Ending Iteration scaling barrier.";
}

bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
  for (auto &round : rounds_) {
    if (!round->ReInitForScaling(server_num)) {
      MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed.";
      return false;
    }
  }
  if (server_rank != kLeaderServerRank) {
    if (!SyncIteration(server_rank)) {
      MS_LOG(ERROR) << "Synchronizing iteration failed.";
      return false;
    }
  }
  return true;
}

bool Iteration::ReInitForUpdatingHyperParams(const std::vector<RoundConfig> &updated_rounds_config) {
  for (const auto &updated_round : updated_rounds_config) {
    for (const auto &round : rounds_) {
      if (updated_round.name == round->name()) {
        MS_LOG(INFO) << "Reinitialize for round " << round->name();
        if (!round->ReInitForUpdatingHyperParams(updated_round.threshold_count, updated_round.time_window)) {
          MS_LOG(ERROR) << "Reinitializing for round " << round->name() << " failed.";
          return false;
        }
      }
    }
  }
  return true;
}

const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return rounds_; }

bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }

void Iteration::set_metrics(const std::shared_ptr<IterationMetrics> &metrics) { metrics_ = metrics; }

void Iteration::set_loss(float loss) { loss_ = loss; }

void Iteration::set_accuracy(float accuracy) { accuracy_ = accuracy; }

InstanceState Iteration::instance_state() const { return instance_state_.load(); }

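// Enable, disable, and new-instance requests are serialized through instance_mtx_ and the
// is_instance_being_updated_ flag, so concurrent instance operations are rejected instead of racing
// with each other.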
bool Iteration::EnableServerInstance(std::string *result) {
  MS_ERROR_IF_NULL_W_RET_VAL(result, false);
  // Before enabling the server instance, judge whether this request should be handled.
  std::unique_lock<std::mutex> lock(instance_mtx_);
  if (is_instance_being_updated_) {
    *result = "The instance is being updated. Please retry enabling the server later.";
    MS_LOG(WARNING) << *result;
    return false;
  }
  if (instance_state_.load() == InstanceState::kFinish) {
    *result = "The instance is completed. Please do not enable the server now.";
    MS_LOG(WARNING) << *result;
    return false;
  }

  // Start enabling the server instance.
  is_instance_being_updated_ = true;

  instance_state_ = InstanceState::kRunning;
  *result = "Enabling FL-Server succeeded.";
  MS_LOG(INFO) << *result;

  // End enabling the server instance.
  is_instance_being_updated_ = false;
  return true;
}

bool Iteration::DisableServerInstance(std::string *result) {
  MS_ERROR_IF_NULL_W_RET_VAL(result, false);
  // Before disabling the server instance, judge whether this request should be handled.
  std::unique_lock<std::mutex> lock(instance_mtx_);
  if (is_instance_being_updated_) {
    *result = "The instance is being updated. Please retry disabling the server later.";
    MS_LOG(WARNING) << *result;
    return false;
  }
  if (instance_state_.load() == InstanceState::kFinish) {
    *result = "The instance is completed. Please do not disable the server now.";
    MS_LOG(WARNING) << *result;
    return false;
  }
  if (instance_state_.load() == InstanceState::kDisable) {
    *result = "Disabling FL-Server succeeded.";
    MS_LOG(INFO) << *result;
    return true;
  }

  // Start disabling the server instance.
  is_instance_being_updated_ = true;

  // If the instance is running, drop the current iteration and move to the next one.
  instance_state_ = InstanceState::kDisable;
  if (!ForciblyMoveToNextIteration()) {
    *result = "Disabling instance failed. Can't drop the current iteration and move to the next one.";
    MS_LOG(ERROR) << *result;
    is_instance_being_updated_ = false;
    return false;
  }
  *result = "Disabling FL-Server succeeded.";
  MS_LOG(INFO) << *result;

  // End disabling the server instance.
  is_instance_being_updated_ = false;
  return true;
}

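// Create a new training instance: finish the current one, wait for in-flight rounds to drain, reset
// the rounds, the iteration counter, the model store, and the metrics file, then apply the new
// hyper-parameters and reinitialize the rounds before switching back to the running state.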
bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string *result) {
  MS_ERROR_IF_NULL_W_RET_VAL(result, false);
  // Before creating a new instance, judge whether this request should be handled.
  std::unique_lock<std::mutex> lock(instance_mtx_);
  if (is_instance_being_updated_) {
    *result = "The instance is being updated. Please retry creating a new instance later.";
    MS_LOG(WARNING) << *result;
    return false;
  }

  // Start creating the new server instance.
  is_instance_being_updated_ = true;

  // Reset the current instance.
  instance_state_ = InstanceState::kFinish;
  Server::GetInstance().WaitExitSafeMode();
  WaitAllRoundsFinish();
  MS_LOG(INFO) << "Proceed to a new instance.";
  for (auto &round : rounds_) {
    MS_ERROR_IF_NULL_W_RET_VAL(round, false);
    round->Reset();
  }
  iteration_num_ = 1;
  LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  ModelStore::GetInstance().Reset();
  if (metrics_ != nullptr) {
    if (!metrics_->Clear()) {
      MS_LOG(WARNING) << "Clearing the metrics file failed.";
    }
  }

  // Update the hyper-parameters on the server and reinitialize the rounds.
  // Reset the updating flag on failure so later instance operations are not rejected forever.
  if (!UpdateHyperParams(new_instance_json)) {
    *result = "Updating hyper-parameters failed.";
    is_instance_being_updated_ = false;
    return false;
  }
  if (!ReInitRounds()) {
    *result = "Reinitializing rounds failed.";
    is_instance_being_updated_ = false;
    return false;
  }

  instance_state_ = InstanceState::kRunning;
  *result = "New FL-Server instance succeeded.";

  // End creating the new server instance.
  is_instance_being_updated_ = false;
  return true;
}

void Iteration::WaitAllRoundsFinish() const {
  while (running_round_num_.load() != 0) {
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
  }
}

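// After scaling, a follower server asks the leader for the current iteration number so that the
// whole cluster agrees on where training resumes.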
bool Iteration::SyncIteration(uint32_t rank) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  SyncIterationRequest sync_iter_req;
  sync_iter_req.set_rank(rank);

  std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
  if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
                                    &sync_iter_rsp_msg)) {
    MS_LOG(ERROR) << "Sending the synchronize-iteration message to the leader server failed.";
    return false;
  }

  MS_ERROR_IF_NULL_W_RET_VAL(sync_iter_rsp_msg, false);
  SyncIterationResponse sync_iter_rsp;
  (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
  iteration_num_ = sync_iter_rsp.iteration();
  MS_LOG(INFO) << "After synchronizing, the current iteration number of server " << rank << " is "
               << sync_iter_rsp.iteration();
  return true;
}

void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);

  SyncIterationRequest sync_iter_req;
  (void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  uint32_t rank = sync_iter_req.rank();
  MS_LOG(INFO) << "Received a synchronize-iteration request from rank " << rank;

  SyncIterationResponse sync_iter_rsp;
  sync_iter_rsp.set_iteration(iteration_num_);
  std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

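// Pin the iteration number that is currently being transitioned. If another move-to-next-iteration
// request arrives for the same iteration (for example from several round kernels at once), it is
// treated as reentrant and ignored.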
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
  std::unique_lock<std::mutex> lock(pinned_mtx_);
  if (pinned_iter_num_ == iteration_num) {
    MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
    return true;
  }
  pinned_iter_num_ = iteration_num;
  return false;
}

bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  MS_LOG(INFO) << "Notify the leader server to move the cluster to the next iteration.";
  NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
  notify_leader_to_next_iter_req.set_rank(server_node_->rank_id());
  notify_leader_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
  notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
  notify_leader_to_next_iter_req.set_reason(reason);
  if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
                                    ps::core::TcpUserCommand::kNotifyLeaderToNextIter)) {
    MS_LOG(WARNING) << "Sending the notify-leader-to-next-iteration request to leader server 0 failed.";
    return false;
  }
  return true;
}

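// Leader-side handler for a follower's notification: acknowledge first, then run the same
// three-phase broadcast as MoveToNextIteration, skipping the request if it is reentrant for the
// reported iteration number.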
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
  NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
  notify_leader_to_next_iter_rsp.set_result("success");
  std::string notify_rsp_msg = notify_leader_to_next_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(notify_rsp_msg.data(), notify_rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }

  NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
  (void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &rank = notify_leader_to_next_iter_req.rank();
  const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
  const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
  const auto &reason = notify_leader_to_next_iter_req.reason();
  MS_LOG(INFO) << "Leader server receives a NotifyLeaderMoveToNextIterRequest from rank " << rank
               << ". Iteration number: " << iter_num << ". Reason: " << reason;

  if (IsMoveToNextIterRequestReentrant(iter_num)) {
    return;
  }

  if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
    MS_LOG(ERROR) << "Broadcasting the prepare-for-next-iteration request failed.";
    return;
  }
  if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
    MS_LOG(ERROR) << "Broadcasting the proceed-to-next-iteration request failed.";
    return;
  }
  if (!BroadcastEndLastIterRequest(iteration_num_)) {
    MS_LOG(ERROR) << "Broadcasting the end-last-iteration request failed.";
    return;
  }
}

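// Phase 1 of the transition: the leader switches itself to safe mode, then tells every follower to
// do the same. Followers that cannot be reached are retried until they recover or the communicator
// stops, because every server must be in safe mode before the iteration state is advanced.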
bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  PrepareForNextIter();
  MS_LOG(INFO) << "Notify all follower servers to prepare for the next iteration.";
  PrepareForNextIterRequest prepare_next_iter_req;
  prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
  prepare_next_iter_req.set_reason(reason);

  std::vector<uint32_t> offline_servers = {};
  for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
    if (!communicator_->SendPbRequest(prepare_next_iter_req, i, ps::core::TcpUserCommand::kPrepareForNextIter)) {
      MS_LOG(WARNING) << "Sending the prepare-for-next-iteration request to server " << i << " failed. Retry later.";
      offline_servers.push_back(i);
      continue;
    }
  }

  // Retry sending to the offline servers to notify them to prepare.
  (void)std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
    // Avoid an endless loop if the server communicator is stopped.
    while (communicator_->running() &&
           !communicator_->SendPbRequest(prepare_next_iter_req, rank, ps::core::TcpUserCommand::kPrepareForNextIter)) {
      MS_LOG(WARNING) << "Retrying the prepare-for-next-iteration request to server " << rank
                      << " failed. The server has not recovered yet.";
      std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
    }
    MS_LOG(INFO) << "Offline server " << rank << " is now prepared for the next iteration.";
  });
  std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
  return true;
}

void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
  PrepareForNextIterRequest prepare_next_iter_req;
  (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &reason = prepare_next_iter_req.reason();
  MS_LOG(INFO) << "Prepare for the next iteration on rank " << server_node_->rank_id() << ", reason: " << reason;
  PrepareForNextIter();

  PrepareForNextIterResponse prepare_next_iter_rsp;
  prepare_next_iter_rsp.set_result("success");
  std::string prepare_rsp_msg = prepare_next_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(prepare_rsp_msg.data(), prepare_rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

void Iteration::PrepareForNextIter() {
  MS_LOG(INFO) << "Prepare for the next iteration. Switch the server to safe mode.";
  Server::GetInstance().SwitchToSafeMode();
  WaitAllRoundsFinish();
}

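// Phase 2 of the transition: the leader pushes the last iteration number and its validity to every
// follower, then advances its own state through Next().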
bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  MS_LOG(INFO) << "Notify all follower servers to proceed to the next iteration. Last iteration number is "
               << iteration_num_;
  MoveToNextIterRequest proceed_to_next_iter_req;
  proceed_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
  proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
  proceed_to_next_iter_req.set_reason(reason);
  for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
    if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, ps::core::TcpUserCommand::kProceedToNextIter)) {
      MS_LOG(WARNING) << "Sending the proceed-to-next-iteration request to server " << i << " failed.";
      continue;
    }
  }

  Next(is_last_iter_valid, reason);
  return true;
}

void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);

  MoveToNextIterRequest proceed_to_next_iter_req;
  (void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
  const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
  const auto &reason = proceed_to_next_iter_req.reason();

  MS_LOG(INFO) << "Received a proceed-to-next-iteration request. The current iteration of this server is "
               << iteration_num_ << ". The iteration number from the leader server is " << last_iter_num
               << ". Last iteration valid: " << is_last_iter_valid << ". Reason: " << reason;
  // Synchronize the iteration number with the leader server.
  iteration_num_ = last_iter_num;
  Next(is_last_iter_valid, reason);

  MoveToNextIterResponse proceed_to_next_iter_rsp;
  proceed_to_next_iter_rsp.set_result("success");
  std::string proceed_rsp_msg = proceed_to_next_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(proceed_rsp_msg.data(), proceed_rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

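// Advance the local iteration state. A valid iteration stores the freshly aggregated model under
// the current iteration number; an invalid one re-stores the latest model from the model store, so
// a model exists for every iteration number. All rounds are reset in both cases.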
void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
  MS_LOG(INFO) << "Prepare for the next iteration.";
  is_last_iteration_valid_ = is_iteration_valid;
  if (is_iteration_valid) {
    // Store the successfully aggregated model for this iteration.
    const auto &model = Executor::GetInstance().GetModel();
    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
    MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
  } else {
    // Store the last iteration's model because this iteration is considered invalid.
    const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
    size_t latest_iter_num = iter_to_model.rbegin()->first;
    const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
    MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
  }

  for (auto &round : rounds_) {
    MS_ERROR_IF_NULL_WO_RET_VAL(round);
    round->Reset();
  }
}

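// Phase 3 of the transition: the leader tells every follower to finalize the iteration that just
// ended, then finalizes it locally in EndLastIter().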
bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  MS_LOG(INFO) << "Notify all follower servers to end the last iteration.";
  EndLastIterRequest end_last_iter_req;
  end_last_iter_req.set_last_iter_num(last_iter_num);
  for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
    if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
      MS_LOG(WARNING) << "Sending the end-last-iteration request to server " << i << " failed.";
      continue;
    }
  }

  EndLastIter();
  return true;
}

void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
  EndLastIterRequest end_last_iter_req;
  (void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &last_iter_num = end_last_iter_req.last_iter_num();
  // If the iteration numbers do not match, return an error.
  if (last_iter_num != iteration_num_) {
    std::string reason = "The iteration of this server " + std::to_string(server_node_->rank_id()) + " is " +
                         std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num);
    EndLastIterResponse end_last_iter_rsp;
    end_last_iter_rsp.set_result(reason);
    std::string mismatch_rsp_msg = end_last_iter_rsp.SerializeAsString();
    if (!communicator_->SendResponse(mismatch_rsp_msg.data(), mismatch_rsp_msg.size(), message)) {
      MS_LOG(ERROR) << "Sending response failed.";
      return;
    }
    return;
  }

  EndLastIter();

  EndLastIterResponse end_last_iter_rsp;
  end_last_iter_rsp.set_result("success");
  std::string end_last_iter_rsp_msg = end_last_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(end_last_iter_rsp_msg.data(), end_last_iter_rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

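// Finalize the iteration that just ended: mark the instance finished when the configured iteration
// count is reached, release the pinned iteration number, record metrics, bump the iteration
// counter, and leave safe mode so the next iteration can accept requests.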
void Iteration::EndLastIter() {
  MS_LOG(INFO) << "End the last iteration " << iteration_num_;
  if (iteration_num_ == ps::PSContext::instance()->fl_iteration_num()) {
    MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
                 << " is completed. Iteration number: " << ps::PSContext::instance()->fl_iteration_num();
    iteration_loop_count_++;
    instance_state_ = InstanceState::kFinish;
  }

  std::unique_lock<std::mutex> lock(pinned_mtx_);
  pinned_iter_num_ = 0;
  lock.unlock();

  SetIterationCompleted();
  if (!SummarizeIteration()) {
    MS_LOG(WARNING) << "Summarizing iteration data failed.";
  }
  iteration_num_++;
  LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  Server::GetInstance().CancelSafeMode();
  iteration_state_cv_.notify_all();
  MS_LOG(INFO) << "Move to next iteration: " << iteration_num_ << "\n";
}

bool Iteration::ForciblyMoveToNextIteration() {
  NotifyNext(false, "Forcibly move to next iteration.");
  return true;
}

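// Persist per-iteration metrics on the leader server. The joined client number is derived from the
// updateModel threshold: for example (illustrative values), with start_fl_job_threshold = 100 and
// update_model_ratio = 0.8, joined = ceil(100 * 0.8) = 80 and rejected = 100 - 80 = 20.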
bool Iteration::SummarizeIteration() {
  // If metrics_ is not initialized or this server is not the leader server, do not summarize.
  if (server_node_->rank_id() != kLeaderServerRank || metrics_ == nullptr) {
    MS_LOG(INFO) << "This server does not summarize the iteration.";
    return true;
  }

  metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
  metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
  metrics_->set_cur_iteration_num(iteration_num_);
  metrics_->set_instance_state(instance_state_.load());
  metrics_->set_loss(loss_);
  metrics_->set_accuracy(accuracy_);
  // The joined client number equals the threshold of updateModel.
  size_t update_model_threshold = static_cast<size_t>(
    std::ceil(ps::PSContext::instance()->start_fl_job_threshold() * ps::PSContext::instance()->update_model_ratio()));
  metrics_->set_joined_client_num(update_model_threshold);
  // The rejected client number equals the threshold of startFLJob minus the threshold of updateModel.
  metrics_->set_rejected_client_num(ps::PSContext::instance()->start_fl_job_threshold() - update_model_threshold);

  if (complete_timestamp_ < start_timestamp_) {
    MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_
                  << ". One of them is invalid.";
    metrics_->set_iteration_time_cost(UINT64_MAX);
  } else {
    metrics_->set_iteration_time_cost(complete_timestamp_ - start_timestamp_);
  }

  if (!metrics_->Summarize()) {
    MS_LOG(ERROR) << "Summarizing metrics failed.";
    return false;
  }
  return true;
}

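// Apply the hyper-parameters from the new-instance JSON to the global PSContext. Only the keys
// handled below are applied; unknown keys are silently ignored. An illustrative payload (field
// values are hypothetical):
//   {"start_fl_job_threshold": 100, "update_model_ratio": 0.8, "fl_iteration_num": 50,
//    "client_epoch_num": 1, "client_batch_size": 32, "client_learning_rate": 0.01}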
bool Iteration::UpdateHyperParams(const nlohmann::json &json) {
  for (const auto &item : json.items()) {
    std::string key = item.key();
    if (key == "start_fl_job_threshold") {
      ps::PSContext::instance()->set_start_fl_job_threshold(item.value().get<uint64_t>());
      continue;
    }
    if (key == "start_fl_job_time_window") {
      ps::PSContext::instance()->set_start_fl_job_time_window(item.value().get<uint64_t>());
      continue;
    }
    if (key == "update_model_ratio") {
      ps::PSContext::instance()->set_update_model_ratio(item.value().get<float>());
      continue;
    }
    if (key == "update_model_time_window") {
      ps::PSContext::instance()->set_update_model_time_window(item.value().get<uint64_t>());
      continue;
    }
    if (key == "fl_iteration_num") {
      ps::PSContext::instance()->set_fl_iteration_num(item.value().get<uint64_t>());
      continue;
    }
    if (key == "client_epoch_num") {
      ps::PSContext::instance()->set_client_epoch_num(item.value().get<uint64_t>());
      continue;
    }
    if (key == "client_batch_size") {
      ps::PSContext::instance()->set_client_batch_size(item.value().get<uint64_t>());
      continue;
    }
    if (key == "client_learning_rate") {
      ps::PSContext::instance()->set_client_learning_rate(item.value().get<float>());
      continue;
    }
  }
  return true;
}

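// Rebuild the startFLJob and updateModel rounds from the freshly updated context, then reconfigure
// the executor. The updateModel threshold is ceil(start_fl_job_threshold * update_model_ratio),
// e.g. (illustrative values) ceil(100 * 0.8) = 80. The executor threshold is that value in
// FL/hybrid mode, or the worker number in PS mode.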
bool Iteration::ReInitRounds() {
  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
  float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
  uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
  uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
  std::vector<RoundConfig> new_round_config = {
    {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
    {"updateModel", true, update_model_time_window, true, update_model_threshold}};
  if (!ReInitForUpdatingHyperParams(new_round_config)) {
    MS_LOG(ERROR) << "Reinitializing for updating hyper-parameters failed.";
    return false;
  }

  size_t executor_threshold = 0;
  const std::string &server_mode = ps::PSContext::instance()->server_mode();
  uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
  if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
    executor_threshold = update_model_threshold;
  } else if (server_mode == ps::kServerModePS) {
    executor_threshold = worker_num;
  } else {
    MS_LOG(ERROR) << "Server mode " << server_mode << " is not supported.";
    return false;
  }
  if (!Executor::GetInstance().ReInitForUpdatingHyperParams(executor_threshold)) {
    MS_LOG(ERROR) << "Reinitializing executor failed.";
    return false;
  }
  return true;
}
}  // namespace server
}  // namespace fl
}  // namespace mindspore